diff --git a/agentrun/__init__.py b/agentrun/__init__.py index bf3e750..9693b1b 100644 --- a/agentrun/__init__.py +++ b/agentrun/__init__.py @@ -113,6 +113,21 @@ SandboxClient, Template, ) +# Super Agent +from agentrun.super_agent import ( + ConversationInfo, + InvokeResponseData, + InvokeStream, +) +from agentrun.super_agent import Message as SuperAgentMessage +from agentrun.super_agent import ( + SSEEvent, + SuperAgent, + SuperAgentClient, + SuperAgentCreateInput, + SuperAgentListInput, + SuperAgentUpdateInput, +) # Tool from agentrun.tool import Tool as ToolResource from agentrun.tool import ToolClient as ToolResourceClient @@ -248,6 +263,20 @@ "ModelProxyCreateInput", "ModelProxyUpdateInput", "ModelProxyListInput", + ######## Super Agent ######## + # base + "SuperAgent", + "SuperAgentClient", + # inner model + "InvokeStream", + "SSEEvent", + "ConversationInfo", + "SuperAgentMessage", + # api model + "SuperAgentCreateInput", + "SuperAgentUpdateInput", + "SuperAgentListInput", + "InvokeResponseData", ######## Sandbox ######## "SandboxClient", "BrowserSandbox", diff --git a/agentrun/agent_runtime/__client_async_template.py b/agentrun/agent_runtime/__client_async_template.py index ac8759a..d72bac6 100644 --- a/agentrun/agent_runtime/__client_async_template.py +++ b/agentrun/agent_runtime/__client_async_template.py @@ -118,7 +118,12 @@ async def delete_async( result = await self.__control_api.delete_agent_runtime_async( id, config=config ) - return AgentRuntime.from_inner_object(result) + # Delete 响应只有 agentRuntimeId 有效,其他字段为空字符串/零值, + # 走 from_inner_object 会在 status/artifactType 等 Enum 字段上触发 + # 伪校验错误。这里直接构造一个只带 id 的极简对象。 + return AgentRuntime.model_construct( + agent_runtime_id=result.agent_runtime_id + ) except HTTPError as e: raise e.to_resource_error("AgentRuntime", id) from e diff --git a/agentrun/agent_runtime/api/control.py b/agentrun/agent_runtime/api/control.py index 8b76d45..c3a3847 100644 --- a/agentrun/agent_runtime/api/control.py +++ b/agentrun/agent_runtime/api/control.py @@ -21,6 +21,9 @@ CreateAgentRuntimeEndpointRequest, CreateAgentRuntimeInput, CreateAgentRuntimeRequest, + DeleteAgentRuntimeEndpointRequest, + DeleteAgentRuntimeRequest, + GetAgentRuntimeEndpointRequest, GetAgentRuntimeRequest, ListAgentRuntimeEndpointsOutput, ListAgentRuntimeEndpointsRequest, @@ -193,6 +196,7 @@ def delete_agent_runtime( client = self._get_client(config) response = client.delete_agent_runtime_with_options( agent_id, + DeleteAgentRuntimeRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -248,6 +252,7 @@ async def delete_agent_runtime_async( client = self._get_client(config) response = await client.delete_agent_runtime_with_options_async( agent_id, + DeleteAgentRuntimeRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -778,6 +783,7 @@ def delete_agent_runtime_endpoint( response = client.delete_agent_runtime_endpoint_with_options( agent_id, endpoint_id, + DeleteAgentRuntimeEndpointRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -838,6 +844,7 @@ async def delete_agent_runtime_endpoint_async( await client.delete_agent_runtime_endpoint_with_options_async( agent_id, endpoint_id, + DeleteAgentRuntimeEndpointRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -1028,6 +1035,7 @@ def get_agent_runtime_endpoint( response = client.get_agent_runtime_endpoint_with_options( agent_id, endpoint_id, + GetAgentRuntimeEndpointRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) @@ -1088,6 +1096,7 @@ async def get_agent_runtime_endpoint_async( await client.get_agent_runtime_endpoint_with_options_async( agent_id, endpoint_id, + GetAgentRuntimeEndpointRequest(), headers=headers or {}, runtime=RuntimeOptions(), ) diff --git a/agentrun/agent_runtime/client.py b/agentrun/agent_runtime/client.py index 826e047..6010463 100644 --- a/agentrun/agent_runtime/client.py +++ b/agentrun/agent_runtime/client.py @@ -171,7 +171,12 @@ async def delete_async( result = await self.__control_api.delete_agent_runtime_async( id, config=config ) - return AgentRuntime.from_inner_object(result) + # Delete 响应只有 agentRuntimeId 有效,其他字段为空字符串/零值, + # 走 from_inner_object 会在 status/artifactType 等 Enum 字段上触发 + # 伪校验错误。这里直接构造一个只带 id 的极简对象。 + return AgentRuntime.model_construct( + agent_runtime_id=result.agent_runtime_id + ) except HTTPError as e: raise e.to_resource_error("AgentRuntime", id) from e @@ -191,7 +196,12 @@ def delete(self, id: str, config: Optional[Config] = None) -> AgentRuntime: """ try: result = self.__control_api.delete_agent_runtime(id, config=config) - return AgentRuntime.from_inner_object(result) + # Delete 响应只有 agentRuntimeId 有效,其他字段为空字符串/零值, + # 走 from_inner_object 会在 status/artifactType 等 Enum 字段上触发 + # 伪校验错误。这里直接构造一个只带 id 的极简对象。 + return AgentRuntime.model_construct( + agent_runtime_id=result.agent_runtime_id + ) except HTTPError as e: raise e.to_resource_error("AgentRuntime", id) from e diff --git a/agentrun/agent_runtime/model.py b/agentrun/agent_runtime/model.py index f3d649e..a169ef3 100644 --- a/agentrun/agent_runtime/model.py +++ b/agentrun/agent_runtime/model.py @@ -185,6 +185,7 @@ class AgentRuntimeProtocolType(str, Enum): HTTP = "HTTP" MCP = "MCP" + SUPER_AGENT = "SUPER_AGENT" class AgentRuntimeProtocolConfig(BaseModel): @@ -251,6 +252,8 @@ class AgentRuntimeMutableProps(BaseModel): """会话空闲超时时间,单位:秒""" tags: Optional[List[str]] = None """标签列表""" + system_tags: Optional[List[str]] = None + """系统标签列表 (由平台内部使用, 例如 SuperAgent 用来标识下游 AgentRuntime)""" class AgentRuntimeImmutableProps(BaseModel): @@ -322,6 +325,8 @@ class AgentRuntimeListInput(PageableInput): """Agent Runtime 名称""" tags: Optional[str] = None """标签过滤,多个标签用逗号分隔""" + system_tags: Optional[str] = None + """系统标签过滤, 多个标签用逗号分隔""" search_mode: Optional[str] = None """搜索模式""" diff --git a/agentrun/integration/utils/tool.py b/agentrun/integration/utils/tool.py index 1fd3666..c479846 100644 --- a/agentrun/integration/utils/tool.py +++ b/agentrun/integration/utils/tool.py @@ -1614,8 +1614,8 @@ def _build_openapi_schema( if isinstance(schema, dict): properties[name] = { **schema, - "description": ( - param.get("description") or schema.get("description", "") + "description": param.get("description") or schema.get( + "description", "" ), } if param.get("required"): diff --git a/agentrun/super_agent/__agent_async_template.py b/agentrun/super_agent/__agent_async_template.py new file mode 100644 index 0000000..9f5ba8b --- /dev/null +++ b/agentrun/super_agent/__agent_async_template.py @@ -0,0 +1,230 @@ +"""SuperAgent 实例 / Super Agent Instance + +``SuperAgent`` 是暴露给应用开发者的强类型实例对象, 承载 ``invoke`` / 会话管理 +两类方法 (仅异步; 见决策 14)。CRUDL 由 ``SuperAgentClient`` 管理。 + +本文件为模板 (``__agent_async_template.py``), codegen 会把 ``async def ...`` +转换成同步骨架; 实际第一版异步主路径 + 同步 NotImplementedError 占位。 +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from agentrun.super_agent.api.data import SuperAgentDataAPI +from agentrun.super_agent.model import ConversationInfo, Message +from agentrun.super_agent.stream import InvokeStream +from agentrun.utils.config import Config + +_SYNC_UNSUPPORTED_MSG = ( + "sync version not supported, use *_async (see decision 14 in" + " openspec/changes/add-super-agent-sdk/design.md)" +) + + +@dataclass +class SuperAgent: + """超级 Agent 实例. + + 业务字段 (``prompt / agents / tools / ...``) 从 ``protocolSettings.config`` + 反解。系统字段 (``agent_runtime_id / arn / status / ...``) 来自 AgentRuntime。 + """ + + name: str + description: Optional[str] = None + prompt: Optional[str] = None + agents: List[str] = field(default_factory=list) + tools: List[str] = field(default_factory=list) + skills: List[str] = field(default_factory=list) + sandboxes: List[str] = field(default_factory=list) + workspaces: List[str] = field(default_factory=list) + model_service_name: Optional[str] = None + model_name: Optional[str] = None + + agent_runtime_id: str = "" + arn: str = "" + status: str = "" + created_at: str = "" + last_updated_at: str = "" + external_endpoint: str = "" + + _client: Any = field(default=None, repr=False, compare=False) + + def _resolve_config(self, config: Optional[Config]) -> Config: + client_cfg = ( + getattr(self._client, "config", None) if self._client else None + ) + return Config.with_configs(client_cfg, config) + + def _forwarded_business_fields(self) -> Dict[str, Any]: + """把 SuperAgent 实例字段打包成 ``forwardedProps`` 顶层业务字段 dict. + + 与 ``protocolSettings[0].config`` 写入时的结构保持对称: list 型用 ``[]`` + 代替 None, scalar 型保留 None (由 JSON 序列化为 ``null``)。服务端读取同 + 一份语义, 避免客户端/服务端对"未设置"产生歧义。 + """ + return { + "prompt": self.prompt, + "agents": list(self.agents), + "tools": list(self.tools), + "skills": list(self.skills), + "sandboxes": list(self.sandboxes), + "workspaces": list(self.workspaces), + "modelServiceName": self.model_service_name, + "modelName": self.model_name, + } + + async def invoke_async( + self, + messages: List[Dict[str, Any]], + *, + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + ) -> InvokeStream: + """Phase 1: POST /invoke; 返回包含 ``conversation_id`` 的 :class:`InvokeStream`. + + 首次 ``async for ev in stream`` 才触发 Phase 2 拉流 (lazy)。 + """ + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + resp = await api.invoke_async( + messages, + conversation_id=conversation_id, + config=cfg, + forwarded_extras=self._forwarded_business_fields(), + ) + stream_url = resp.stream_url + stream_headers = dict(resp.stream_headers) + session_id = stream_headers.get("X-Super-Agent-Session-Id", "") + + async def _factory(): + return api.stream_async( + stream_url, stream_headers=stream_headers, config=cfg + ) + + return InvokeStream( + conversation_id=resp.conversation_id, + session_id=session_id, + stream_url=stream_url, + stream_headers=stream_headers, + _stream_factory=_factory, + ) + + def invoke( + self, + messages: List[Dict[str, Any]], + *, + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + ) -> InvokeStream: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def list_conversations_async( + self, + *, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[ConversationInfo]: + """GET /conversations → ``List[ConversationInfo]``. + + 默认按 ``{"agentRuntimeName": self.name}`` 过滤 (仅返回当前 agent 的会话); + 传入 ``metadata={}`` 或自定义 dict 可覆盖默认过滤条件。 + """ + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + effective_metadata = ( + {"agentRuntimeName": self.name} if metadata is None else metadata + ) + raw_list = await api.list_conversations_async( + metadata=effective_metadata, config=cfg + ) + return [ + _conversation_info_from_dict( + raw, + fallback_conversation_id=str(raw.get("conversationId") or ""), + ) + for raw in raw_list + ] + + def list_conversations( + self, + *, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[ConversationInfo]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def get_conversation_async( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> ConversationInfo: + """GET /conversations/{id} → :class:`ConversationInfo` (缺字段用默认值).""" + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + data = await api.get_conversation_async(conversation_id, config=cfg) + return _conversation_info_from_dict( + data, fallback_conversation_id=conversation_id + ) + + def get_conversation( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> ConversationInfo: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def delete_conversation_async( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> None: + """DELETE /conversations/{id}.""" + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + await api.delete_conversation_async(conversation_id, config=cfg) + + def delete_conversation( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> None: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + +def _to_message(raw: Dict[str, Any]) -> Message: + return Message( + role=str(raw.get("role") or ""), + content=str(raw.get("content") or ""), + message_id=raw.get("messageId") or raw.get("message_id"), + created_at=raw.get("createdAt") or raw.get("created_at"), + ) + + +def _conversation_info_from_dict( + data: Dict[str, Any], *, fallback_conversation_id: str +) -> ConversationInfo: + data = data or {} + messages_raw = data.get("messages") or [] + messages = [_to_message(m) for m in messages_raw if isinstance(m, dict)] + return ConversationInfo( + conversation_id=str( + data.get("conversationId") or fallback_conversation_id + ), + agent_id=str(data.get("agentId") or data.get("agent_id") or ""), + title=data.get("title"), + main_user_id=data.get("mainUserId") or data.get("main_user_id"), + sub_user_id=data.get("subUserId") or data.get("sub_user_id"), + created_at=int(data.get("createdAt") or data.get("created_at") or 0), + updated_at=int(data.get("updatedAt") or data.get("updated_at") or 0), + error_message=data.get("errorMessage") or data.get("error_message"), + invoke_info=data.get("invokeInfo") or data.get("invoke_info"), + messages=messages, + params=data.get("params"), + ) + + +__all__ = ["SuperAgent"] diff --git a/agentrun/super_agent/__client_async_template.py b/agentrun/super_agent/__client_async_template.py new file mode 100644 index 0000000..51b5723 --- /dev/null +++ b/agentrun/super_agent/__client_async_template.py @@ -0,0 +1,521 @@ +"""SuperAgentClient / 超级 Agent 客户端 + +对外入口: CRUDL (create / get / update / delete / list / list_all) 同步 + 异步双写。 +内部持有一个 :class:`AgentRuntimeClient` 实例, 通过 ``api/control.py`` 的 +转换函数把 ``SuperAgent`` 与 ``AgentRuntime`` 互相映射。 + +list 固定按 systemTag ``x-agentrun-super`` 过滤, 不接受用户自定义 tag。 +""" + +import asyncio +import time +from typing import Any, List, Optional + +from alibabacloud_agentrun20250910.models import ( + CreateAgentRuntimeInput, + UpdateAgentRuntimeInput, +) + +from agentrun.agent_runtime.api import AgentRuntimeControlAPI +from agentrun.agent_runtime.client import AgentRuntimeClient +from agentrun.agent_runtime.model import AgentRuntimeListInput +from agentrun.agent_runtime.runtime import AgentRuntime +from agentrun.super_agent.agent import SuperAgent +from agentrun.super_agent.api.control import ( + ensure_super_agent_patches_applied, + from_agent_runtime, + is_super_agent, + SUPER_AGENT_TAG, + to_create_input, + to_update_input, +) +from agentrun.utils.config import Config +from agentrun.utils.log import logger +from agentrun.utils.model import Status + +# 公开 API 签名故意保持 ``Optional[X] = None`` 对外简洁; +# ``_UNSET`` 仅用于内部区分 "未传" 与 "显式 None (= 清空)". +_UNSET: Any = object() + +# create/update 轮询默认参数 +_WAIT_INTERVAL_SECONDS = 3 +_WAIT_TIMEOUT_SECONDS = 300 + + +def _raise_if_failed(rt: AgentRuntime, action: str) -> None: + """若 rt 处于失败态, 抛出带 status_reason 的 RuntimeError.""" + status = getattr(rt, "status", None) + status_str = str(status) if status is not None else "" + if status_str in { + Status.CREATE_FAILED.value, + Status.UPDATE_FAILED.value, + Status.DELETE_FAILED.value, + }: + reason = getattr(rt, "status_reason", None) or "(no reason)" + name = getattr(rt, "agent_runtime_name", None) or "(unknown)" + raise RuntimeError( + f"Super agent {action} failed: name={name!r} status={status_str} " + f"reason={reason}" + ) + + +def _merge(current: dict, updates: dict) -> dict: + """把 ``updates`` 中非 ``_UNSET`` 的字段合并到 ``current`` (None 表示清空).""" + merged = dict(current) + for key, value in updates.items(): + if value is _UNSET: + continue + merged[key] = value + return merged + + +def _super_agent_to_business_dict(agent: SuperAgent) -> dict: + return { + "description": agent.description, + "prompt": agent.prompt, + "agents": list(agent.agents), + "tools": list(agent.tools), + "skills": list(agent.skills), + "sandboxes": list(agent.sandboxes), + "workspaces": list(agent.workspaces), + "model_service_name": agent.model_service_name, + "model_name": agent.model_name, + } + + +class SuperAgentClient: + """Super Agent CRUDL 客户端.""" + + def __init__(self, config: Optional[Config] = None) -> None: + # 按需打 Dara SDK 兼容补丁 (幂等)。放在本构造函数里, 让 "仅 import + # agentrun.super_agent" 的调用方不被动承担全局 SDK 副作用。 + ensure_super_agent_patches_applied() + self.config = config + self._rt = AgentRuntimeClient(config=config) + # create/update 绕过 AgentRuntimeClient 的 artifact_type 校验 (SUPER_AGENT 不需要 code/container), + # 并通过 ``ProtocolConfiguration`` 的 monkey-patch 保留 ``externalEndpoint`` 字段。 + self._rt_control = AgentRuntimeControlAPI(config=config) + + async def _wait_final_async( + self, + agent_runtime_id: str, + *, + config: Optional[Config] = None, + interval_seconds: int = _WAIT_INTERVAL_SECONDS, + timeout_seconds: int = _WAIT_TIMEOUT_SECONDS, + ) -> AgentRuntime: + """轮询 get 直到 status 进入最终态 (READY / *_FAILED).""" + cfg = Config.with_configs(self.config, config) + start = time.monotonic() + while True: + rt = await self._rt.get_async(agent_runtime_id, config=cfg) + status = getattr(rt, "status", None) + logger.debug( + "super agent %s poll status=%s", agent_runtime_id, status + ) + if Status.is_final_status(status): + return rt + if time.monotonic() - start > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for super agent {agent_runtime_id!r}" + f" to reach final status (last status={status})" + ) + await asyncio.sleep(interval_seconds) + + def _wait_final( + self, + agent_runtime_id: str, + *, + config: Optional[Config] = None, + interval_seconds: int = _WAIT_INTERVAL_SECONDS, + timeout_seconds: int = _WAIT_TIMEOUT_SECONDS, + ) -> AgentRuntime: + """同步版 _wait_final_async.""" + cfg = Config.with_configs(self.config, config) + start = time.monotonic() + while True: + rt = self._rt.get(agent_runtime_id, config=cfg) + status = getattr(rt, "status", None) + logger.debug( + "super agent %s poll status=%s", agent_runtime_id, status + ) + if Status.is_final_status(status): + return rt + if time.monotonic() - start > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for super agent {agent_runtime_id!r}" + f" to reach final status (last status={status})" + ) + time.sleep(interval_seconds) + + # ─── Create ────────────────────────────────────── + async def create_async( + self, + *, + name: str, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> SuperAgent: + """异步创建超级 Agent.""" + cfg = Config.with_configs(self.config, config) + rt_input = to_create_input( + name, + description=description, + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + cfg=cfg, + ) + dara_input = CreateAgentRuntimeInput().from_map(rt_input.model_dump()) + result = await self._rt_control.create_agent_runtime_async( + dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + # 轮询直到进入最终态; 失败则抛出带 status_reason 的错误。 + agent_id = getattr(rt, "agent_runtime_id", None) + if agent_id: + rt = await self._wait_final_async(agent_id, config=cfg) + _raise_if_failed(rt, action="create") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def create( + self, + *, + name: str, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> SuperAgent: + """同步创建超级 Agent.""" + cfg = Config.with_configs(self.config, config) + rt_input = to_create_input( + name, + description=description, + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + cfg=cfg, + ) + dara_input = CreateAgentRuntimeInput().from_map(rt_input.model_dump()) + result = self._rt_control.create_agent_runtime(dara_input, config=cfg) + rt = AgentRuntime.from_inner_object(result) + agent_id = getattr(rt, "agent_runtime_id", None) + if agent_id: + rt = self._wait_final(agent_id, config=cfg) + _raise_if_failed(rt, action="create") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Get ────────────────────────────────────────── + # Aliyun 控制面 get/delete/update 接口只认 ``agent_runtime_id`` (URN), + # 不认 resource_name; ``_find_rt_by_name*`` 用 list + 名称匹配来解析 id. + def _find_rt_by_name(self, name: str, config: Optional[Config]) -> Any: + cfg = Config.with_configs(self.config, config) + page_number = 1 + page_size = 50 + while True: + runtimes = self._rt.list( + AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + system_tags=SUPER_AGENT_TAG, + ), + config=cfg, + ) + for rt in runtimes: + if getattr(rt, "agent_runtime_name", None) == name: + return rt + if len(runtimes) < page_size: + return None + page_number += 1 + + async def _find_rt_by_name_async( + self, name: str, config: Optional[Config] + ) -> Any: + cfg = Config.with_configs(self.config, config) + page_number = 1 + page_size = 50 + while True: + runtimes = await self._rt.list_async( + AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + system_tags=SUPER_AGENT_TAG, + ), + config=cfg, + ) + for rt in runtimes: + if getattr(rt, "agent_runtime_name", None) == name: + return rt + if len(runtimes) < page_size: + return None + page_number += 1 + + async def get_async( + self, name: str, *, config: Optional[Config] = None + ) -> SuperAgent: + """异步获取超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def get(self, name: str, *, config: Optional[Config] = None) -> SuperAgent: + """同步获取超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Update (read-merge-write) ───────────────────── + # 参数默认值 ``_UNSET`` 是内部哨兵 (object())。为保留 IDE 自动补全与 mypy + # 类型检查, 签名保持精确类型标注, 对 ``= _UNSET`` 的赋值加 ``type: ignore``。 + # 未传 = 保持不变, 显式传 None = 清空字段。 + async def update_async( + self, + name: str, + *, + description: Optional[str] = _UNSET, # type: ignore[assignment] + prompt: Optional[str] = _UNSET, # type: ignore[assignment] + agents: Optional[List[str]] = _UNSET, # type: ignore[assignment] + tools: Optional[List[str]] = _UNSET, # type: ignore[assignment] + skills: Optional[List[str]] = _UNSET, # type: ignore[assignment] + sandboxes: Optional[List[str]] = _UNSET, # type: ignore[assignment] + workspaces: Optional[List[str]] = _UNSET, # type: ignore[assignment] + model_service_name: Optional[str] = _UNSET, # type: ignore[assignment] + model_name: Optional[str] = _UNSET, # type: ignore[assignment] + config: Optional[Config] = None, + ) -> SuperAgent: + """异步更新超级 Agent (read-merge-write).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + current = _super_agent_to_business_dict(from_agent_runtime(rt)) + updates = { + "description": description, + "prompt": prompt, + "agents": agents, + "tools": tools, + "skills": skills, + "sandboxes": sandboxes, + "workspaces": workspaces, + "model_service_name": model_service_name, + "model_name": model_name, + } + merged = _merge(current, updates) + rt_input = to_update_input(name, merged, cfg=cfg) + dara_input = UpdateAgentRuntimeInput().from_map(rt_input.model_dump()) + agent_id = getattr(rt, "agent_runtime_id", None) or name + result = await self._rt_control.update_agent_runtime_async( + agent_id, dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + rt_id = getattr(rt, "agent_runtime_id", None) or agent_id + if rt_id: + rt = await self._wait_final_async(rt_id, config=cfg) + _raise_if_failed(rt, action="update") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def update( + self, + name: str, + *, + description: Optional[str] = _UNSET, # type: ignore[assignment] + prompt: Optional[str] = _UNSET, # type: ignore[assignment] + agents: Optional[List[str]] = _UNSET, # type: ignore[assignment] + tools: Optional[List[str]] = _UNSET, # type: ignore[assignment] + skills: Optional[List[str]] = _UNSET, # type: ignore[assignment] + sandboxes: Optional[List[str]] = _UNSET, # type: ignore[assignment] + workspaces: Optional[List[str]] = _UNSET, # type: ignore[assignment] + model_service_name: Optional[str] = _UNSET, # type: ignore[assignment] + model_name: Optional[str] = _UNSET, # type: ignore[assignment] + config: Optional[Config] = None, + ) -> SuperAgent: + """同步更新超级 Agent (read-merge-write).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + current = _super_agent_to_business_dict(from_agent_runtime(rt)) + updates = { + "description": description, + "prompt": prompt, + "agents": agents, + "tools": tools, + "skills": skills, + "sandboxes": sandboxes, + "workspaces": workspaces, + "model_service_name": model_service_name, + "model_name": model_name, + } + merged = _merge(current, updates) + rt_input = to_update_input(name, merged, cfg=cfg) + dara_input = UpdateAgentRuntimeInput().from_map(rt_input.model_dump()) + agent_id = getattr(rt, "agent_runtime_id", None) or name + result = self._rt_control.update_agent_runtime( + agent_id, dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + rt_id = getattr(rt, "agent_runtime_id", None) or agent_id + if rt_id: + rt = self._wait_final(rt_id, config=cfg) + _raise_if_failed(rt, action="update") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Delete ─────────────────────────────────────── + async def delete_async( + self, name: str, *, config: Optional[Config] = None + ) -> None: + """异步删除超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + agent_id = getattr(rt, "agent_runtime_id", None) or name + await self._rt.delete_async(agent_id, config=cfg) + + def delete(self, name: str, *, config: Optional[Config] = None) -> None: + """同步删除超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + agent_id = getattr(rt, "agent_runtime_id", None) or name + self._rt.delete(agent_id, config=cfg) + + # ─── List ───────────────────────────────────────── + async def list_async( + self, + *, + page_number: int = 1, + page_size: int = 20, + config: Optional[Config] = None, + ) -> List[SuperAgent]: + """异步列出超级 Agent (固定 systemTag 过滤, 过滤非 SUPER_AGENT).""" + cfg = Config.with_configs(self.config, config) + rt_input = AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + system_tags=SUPER_AGENT_TAG, + ) + runtimes = await self._rt.list_async(rt_input, config=cfg) + result: List[SuperAgent] = [] + for rt in runtimes: + if not is_super_agent(rt): + continue + agent = from_agent_runtime(rt) + agent._client = self + result.append(agent) + return result + + def list( + self, + *, + page_number: int = 1, + page_size: int = 20, + config: Optional[Config] = None, + ) -> List[SuperAgent]: + """同步列出超级 Agent (固定 systemTag 过滤, 过滤非 SUPER_AGENT).""" + cfg = Config.with_configs(self.config, config) + rt_input = AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + system_tags=SUPER_AGENT_TAG, + ) + runtimes = self._rt.list(rt_input, config=cfg) + result: List[SuperAgent] = [] + for rt in runtimes: + if not is_super_agent(rt): + continue + agent = from_agent_runtime(rt) + agent._client = self + result.append(agent) + return result + + async def list_all_async( + self, *, config: Optional[Config] = None, page_size: int = 50 + ) -> List[SuperAgent]: + """异步一次性拉取所有超级 Agent (自动分页).""" + cfg = Config.with_configs(self.config, config) + result: List[SuperAgent] = [] + page_number = 1 + while True: + page = await self.list_async( + page_number=page_number, page_size=page_size, config=cfg + ) + if not page: + break + result.extend(page) + if len(page) < page_size: + break + page_number += 1 + return result + + def list_all( + self, *, config: Optional[Config] = None, page_size: int = 50 + ) -> List[SuperAgent]: + """同步一次性拉取所有超级 Agent (自动分页).""" + cfg = Config.with_configs(self.config, config) + result: List[SuperAgent] = [] + page_number = 1 + while True: + page = self.list( + page_number=page_number, page_size=page_size, config=cfg + ) + if not page: + break + result.extend(page) + if len(page) < page_size: + break + page_number += 1 + return result + + +__all__ = ["SuperAgentClient"] diff --git a/agentrun/super_agent/__init__.py b/agentrun/super_agent/__init__.py new file mode 100644 index 0000000..5afb68e --- /dev/null +++ b/agentrun/super_agent/__init__.py @@ -0,0 +1,52 @@ +"""Super Agent 模块 / Super Agent Module + +独立的超级 Agent SDK, 面向写应用的开发者。用户无需感知底层 AgentRuntime 概念, +只需: + +.. code-block:: python + + from agentrun.super_agent import SuperAgentClient + + client = SuperAgentClient() + agent = await client.create_async(name="my-agent", prompt="你好") + stream = await agent.invoke_async(messages=[{"role": "user", "content": "hi"}]) + print(stream.conversation_id) + async for ev in stream: + print(ev.event, ev.data) + +可选的 AG-UI 强类型适配, 显式导入: + +.. code-block:: python + + from agentrun.super_agent.agui import as_agui_events + + async for event in as_agui_events(stream): + ... # event 是 ``ag_ui.core.BaseEvent`` 子类实例 + +详见 ``openspec/changes/add-super-agent-sdk/`` 中的 proposal、design、spec。 +""" + +from .agent import SuperAgent +from .client import SuperAgentClient +from .model import ( + ConversationInfo, + InvokeResponseData, + Message, + SuperAgentCreateInput, + SuperAgentListInput, + SuperAgentUpdateInput, +) +from .stream import InvokeStream, SSEEvent + +__all__ = [ + "SuperAgentClient", + "SuperAgent", + "InvokeStream", + "SSEEvent", + "SuperAgentCreateInput", + "SuperAgentUpdateInput", + "SuperAgentListInput", + "ConversationInfo", + "Message", + "InvokeResponseData", +] diff --git a/agentrun/super_agent/agent.py b/agentrun/super_agent/agent.py new file mode 100644 index 0000000..2460041 --- /dev/null +++ b/agentrun/super_agent/agent.py @@ -0,0 +1,238 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/super_agent/__agent_async_template.py + +SuperAgent 实例 / Super Agent Instance + +``SuperAgent`` 是暴露给应用开发者的强类型实例对象, 承载 ``invoke`` / 会话管理 +两类方法 (仅异步; 见决策 14)。CRUDL 由 ``SuperAgentClient`` 管理。 +同步方法保留 ``NotImplementedError`` 占位, 以备未来扩展。 +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from agentrun.super_agent.api.data import SuperAgentDataAPI +from agentrun.super_agent.model import ConversationInfo, Message +from agentrun.super_agent.stream import InvokeStream +from agentrun.utils.config import Config + +_SYNC_UNSUPPORTED_MSG = ( + "sync version not supported, use *_async (see decision 14 in" + " openspec/changes/add-super-agent-sdk/design.md)" +) + + +@dataclass +class SuperAgent: + """超级 Agent 实例. + + 业务字段 (``prompt / agents / tools / ...``) 从 ``protocolSettings.config`` + 反解。系统字段 (``agent_runtime_id / arn / status / ...``) 来自 AgentRuntime。 + """ + + name: str + description: Optional[str] = None + prompt: Optional[str] = None + agents: List[str] = field(default_factory=list) + tools: List[str] = field(default_factory=list) + skills: List[str] = field(default_factory=list) + sandboxes: List[str] = field(default_factory=list) + workspaces: List[str] = field(default_factory=list) + model_service_name: Optional[str] = None + model_name: Optional[str] = None + + agent_runtime_id: str = "" + arn: str = "" + status: str = "" + created_at: str = "" + last_updated_at: str = "" + external_endpoint: str = "" + + _client: Any = field(default=None, repr=False, compare=False) + + def _resolve_config(self, config: Optional[Config]) -> Config: + client_cfg = ( + getattr(self._client, "config", None) if self._client else None + ) + return Config.with_configs(client_cfg, config) + + def _forwarded_business_fields(self) -> Dict[str, Any]: + """把 SuperAgent 实例字段打包成 ``forwardedProps`` 顶层业务字段 dict. + + 与 ``protocolSettings[0].config`` 写入时的结构保持对称: list 型用 ``[]`` + 代替 None, scalar 型保留 None (由 JSON 序列化为 ``null``)。服务端读取同 + 一份语义, 避免客户端/服务端对"未设置"产生歧义。 + """ + return { + "prompt": self.prompt, + "agents": list(self.agents), + "tools": list(self.tools), + "skills": list(self.skills), + "sandboxes": list(self.sandboxes), + "workspaces": list(self.workspaces), + "modelServiceName": self.model_service_name, + "modelName": self.model_name, + } + + async def invoke_async( + self, + messages: List[Dict[str, Any]], + *, + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + ) -> InvokeStream: + """Phase 1: POST /invoke; 返回包含 ``conversation_id`` 的 :class:`InvokeStream`. + + 首次 ``async for ev in stream`` 才触发 Phase 2 拉流 (lazy)。 + """ + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + resp = await api.invoke_async( + messages, + conversation_id=conversation_id, + config=cfg, + forwarded_extras=self._forwarded_business_fields(), + ) + stream_url = resp.stream_url + stream_headers = dict(resp.stream_headers) + session_id = stream_headers.get("X-Super-Agent-Session-Id", "") + + async def _factory(): + return api.stream_async( + stream_url, stream_headers=stream_headers, config=cfg + ) + + return InvokeStream( + conversation_id=resp.conversation_id, + session_id=session_id, + stream_url=stream_url, + stream_headers=stream_headers, + _stream_factory=_factory, + ) + + def invoke( + self, + messages: List[Dict[str, Any]], + *, + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + ) -> InvokeStream: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def list_conversations_async( + self, + *, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[ConversationInfo]: + """GET /conversations → ``List[ConversationInfo]``. + + 默认按 ``{"agentRuntimeName": self.name}`` 过滤 (仅返回当前 agent 的会话); + 传入 ``metadata={}`` 或自定义 dict 可覆盖默认过滤条件。 + """ + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + effective_metadata = ( + {"agentRuntimeName": self.name} if metadata is None else metadata + ) + raw_list = await api.list_conversations_async( + metadata=effective_metadata, config=cfg + ) + return [ + _conversation_info_from_dict( + raw, + fallback_conversation_id=str(raw.get("conversationId") or ""), + ) + for raw in raw_list + ] + + def list_conversations( + self, + *, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[ConversationInfo]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def get_conversation_async( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> ConversationInfo: + """GET /conversations/{id} → :class:`ConversationInfo` (缺字段用默认值).""" + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + data = await api.get_conversation_async(conversation_id, config=cfg) + return _conversation_info_from_dict( + data, fallback_conversation_id=conversation_id + ) + + def get_conversation( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> ConversationInfo: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def delete_conversation_async( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> None: + """DELETE /conversations/{id}.""" + cfg = self._resolve_config(config) + api = SuperAgentDataAPI(self.name, config=cfg) + await api.delete_conversation_async(conversation_id, config=cfg) + + def delete_conversation( + self, + conversation_id: str, + *, + config: Optional[Config] = None, + ) -> None: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + +def _to_message(raw: Dict[str, Any]) -> Message: + return Message( + role=str(raw.get("role") or ""), + content=str(raw.get("content") or ""), + message_id=raw.get("messageId") or raw.get("message_id"), + created_at=raw.get("createdAt") or raw.get("created_at"), + ) + + +def _conversation_info_from_dict( + data: Dict[str, Any], *, fallback_conversation_id: str +) -> ConversationInfo: + data = data or {} + messages_raw = data.get("messages") or [] + messages = [_to_message(m) for m in messages_raw if isinstance(m, dict)] + return ConversationInfo( + conversation_id=str( + data.get("conversationId") or fallback_conversation_id + ), + agent_id=str(data.get("agentId") or data.get("agent_id") or ""), + title=data.get("title"), + main_user_id=data.get("mainUserId") or data.get("main_user_id"), + sub_user_id=data.get("subUserId") or data.get("sub_user_id"), + created_at=int(data.get("createdAt") or data.get("created_at") or 0), + updated_at=int(data.get("updatedAt") or data.get("updated_at") or 0), + error_message=data.get("errorMessage") or data.get("error_message"), + invoke_info=data.get("invokeInfo") or data.get("invoke_info"), + messages=messages, + params=data.get("params"), + ) + + +__all__ = ["SuperAgent"] diff --git a/agentrun/super_agent/agui.py b/agentrun/super_agent/agui.py new file mode 100644 index 0000000..b6904d1 --- /dev/null +++ b/agentrun/super_agent/agui.py @@ -0,0 +1,136 @@ +"""Super Agent AG-UI 适配器 / Super Agent AG-UI Adapter + +把 :class:`InvokeStream` 的原始 :class:`SSEEvent` 解码为 ``ag_ui.core.BaseEvent`` +强类型事件 (**client 侧解码**)。 + +与 ``agentrun/server/agui_protocol.py`` (server 侧编码方向) 是反向关系, 互不依赖: +本文件只消费 ``ag_ui.core`` 的事件类, server 侧则负责产生它们。 + +使用: + +.. code-block:: python + + from agentrun.super_agent.agui import as_agui_events + + async for event in as_agui_events(stream): + # event 是 ag_ui.core.BaseEvent 子类 (如 TextMessageContentEvent) + ... + + # 想跳过未知事件 (如过渡期兼容), 传 on_unknown="skip" + async for event in as_agui_events(stream, on_unknown="skip"): + ... +""" + +from __future__ import annotations + +from typing import AsyncGenerator, Dict, Literal, Optional, Type + +from ag_ui.core import ( + BaseEvent, + CustomEvent, + MessagesSnapshotEvent, + RawEvent, + RunErrorEvent, + RunFinishedEvent, + RunStartedEvent, + StateDeltaEvent, + StateSnapshotEvent, + TextMessageContentEvent, + TextMessageEndEvent, + TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, +) + +from .stream import InvokeStream, SSEEvent + +_EVENT_TYPE_TO_CLASS: Dict[str, Type[BaseEvent]] = { + "RUN_STARTED": RunStartedEvent, + "RUN_FINISHED": RunFinishedEvent, + "RUN_ERROR": RunErrorEvent, + "TEXT_MESSAGE_START": TextMessageStartEvent, + "TEXT_MESSAGE_CONTENT": TextMessageContentEvent, + "TEXT_MESSAGE_END": TextMessageEndEvent, + "TOOL_CALL_START": ToolCallStartEvent, + "TOOL_CALL_ARGS": ToolCallArgsEvent, + "TOOL_CALL_END": ToolCallEndEvent, + "TOOL_CALL_RESULT": ToolCallResultEvent, + "STATE_SNAPSHOT": StateSnapshotEvent, + "STATE_DELTA": StateDeltaEvent, + "MESSAGES_SNAPSHOT": MessagesSnapshotEvent, + "RAW": RawEvent, + "CUSTOM": CustomEvent, +} + + +UnknownMode = Literal["raise", "skip"] + + +def _data_preview(data: str, limit: int = 200) -> str: + if len(data) <= limit: + return data + return data[:limit] + "..." + + +def decode_sse_to_agui( + sse_event: SSEEvent, + *, + on_unknown: UnknownMode = "raise", +) -> Optional[BaseEvent]: + """把单个 :class:`SSEEvent` 解码为 AG-UI 事件. + + - 空 ``data`` (keepalive) → 返回 ``None`` + - 未知 ``event`` + ``on_unknown='raise'`` → 抛 ``ValueError`` + - 未知 ``event`` + ``on_unknown='skip'`` → 返回 ``None`` + - ``data`` 不合法 JSON 或 Pydantic 校验失败 → 抛 ``ValueError`` + (不论 ``on_unknown``, 因为这通常是真实错误) + """ + if sse_event.data == "": + return None + + event_name = sse_event.event + if not event_name or event_name not in _EVENT_TYPE_TO_CLASS: + if on_unknown == "skip": + return None + raise ValueError( + f"Unknown AG-UI event type {event_name!r}; data prefix:" + f" {_data_preview(sse_event.data)}" + ) + + cls = _EVENT_TYPE_TO_CLASS[event_name] + try: + return cls.model_validate_json(sse_event.data) + except Exception as exc: # JSONDecodeError, ValidationError, etc. + raise ValueError( + f"Failed to decode AG-UI {event_name!r} event: {exc};" + f" data prefix: {_data_preview(sse_event.data)}" + ) from exc + + +async def as_agui_events( + stream: InvokeStream, + *, + on_unknown: UnknownMode = "raise", +) -> AsyncGenerator[BaseEvent, None]: + """把 :class:`InvokeStream` 中的原始 :class:`SSEEvent` 解码为强类型流. + + 无论正常消费结束、中途异常、解码异常, 都保证 ``await stream.aclose()`` 被调用 + 以释放 httpx 连接。 + """ + try: + async for sse_event in stream: + agui_event = decode_sse_to_agui(sse_event, on_unknown=on_unknown) + if agui_event is None: + continue + yield agui_event + finally: + await stream.aclose() + + +__all__ = [ + "as_agui_events", + "decode_sse_to_agui", + "_EVENT_TYPE_TO_CLASS", +] diff --git a/agentrun/super_agent/api/__data_async_template.py b/agentrun/super_agent/api/__data_async_template.py new file mode 100644 index 0000000..9a7549e --- /dev/null +++ b/agentrun/super_agent/api/__data_async_template.py @@ -0,0 +1,318 @@ +"""Super Agent 数据面 API / Super Agent Data Plane API + +继承 :class:`agentrun.utils.data_api.DataAPI`, 复用: + +- ``_get_ram_data_endpoint``: RAM ``-ram`` 前缀改写 +- ``get_base_url`` / ``with_path``: URL 拼接 (含版本号) +- ``auth``: RAM 签名头生成 (``Agentrun-Authorization`` / ``x-acs-*``) + +本文件为 **模板** (``__data_async_template.py``); +当前 ``codegen`` 的 async→sync 转换不支持 async generator (``async for + yield``) +以及 ``async with ... as x`` 里的作用域保留, 所以 ``api/data.py`` 的同步骨架 +(同步方法第一版仅保留 ``NotImplementedError`` 占位) 目前 **手工维护**。 +运行 ``make codegen`` 会生成不可运行的版本, 请不要直接覆盖。 +""" + +import json +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional + +import httpx + +from agentrun.super_agent.model import InvokeResponseData +from agentrun.super_agent.stream import parse_sse_async, SSEEvent +from agentrun.utils.config import Config +from agentrun.utils.data_api import DataAPI, ResourceType +from agentrun.utils.log import logger + +API_VERSION = "2025-09-10" +SUPER_AGENT_RESOURCE_PATH = "__SUPER_AGENT__" +SUPER_AGENT_NAMESPACE = ( + f"{API_VERSION}/super-agents/{SUPER_AGENT_RESOURCE_PATH}" +) + +_SYNC_UNSUPPORTED_MSG = ( + "sync version not supported, use *_async (see decision 14 in" + " openspec/changes/add-super-agent-sdk/design.md)" +) + + +class SuperAgentDataAPI(DataAPI): + """Super Agent 数据面 API (异步主路径 + 同步占位).""" + + def __init__( + self, + agent_runtime_name: str, + config: Optional[Config] = None, + ): + super().__init__( + resource_name=agent_runtime_name, + resource_type=ResourceType.Runtime, + namespace=SUPER_AGENT_NAMESPACE, + config=config, + ) + self.agent_runtime_name = agent_runtime_name + + def _build_invoke_body( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str], + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + # ``forwarded_extras`` 承载从 AgentRuntime 元数据读出的业务字段 + # (prompt/agents/tools/skills/sandboxes/workspaces/modelServiceName/modelName), + # 由上层 ``SuperAgent.invoke_async`` 注入。``metadata`` 和 ``conversationId`` + # 由 SDK 管理, 不允许 extras 覆盖。 + forwarded: Dict[str, Any] = dict(forwarded_extras or {}) + forwarded["metadata"] = {"agentRuntimeName": self.agent_runtime_name} + if conversation_id is not None: + forwarded["conversationId"] = conversation_id + return {"messages": list(messages), "forwardedProps": forwarded} + + def _parse_invoke_response( + self, payload: Dict[str, Any] + ) -> InvokeResponseData: + if not isinstance(payload, dict): + raise ValueError( + "Invalid invoke response: expected object, got" + f" {type(payload).__name__}" + ) + data = payload.get("data") + if not isinstance(data, dict): + raise ValueError("Invalid invoke response: missing data field") + conversation_id = data.get("conversationId") + if not conversation_id: + raise ValueError( + "Invalid invoke response: missing conversationId field" + ) + stream_url = data.get("url") + if not stream_url: + raise ValueError("Invalid invoke response: missing url field") + raw_headers = data.get("headers") or {} + stream_headers = {str(k): str(v) for k, v in dict(raw_headers).items()} + return InvokeResponseData( + conversation_id=str(conversation_id), + stream_url=str(stream_url), + stream_headers=stream_headers, + ) + + async def invoke_async( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> InvokeResponseData: + """Phase 1: POST /invoke, 返回 Phase 2 的 URL 与 headers.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path("invoke", config=cfg) + body = self._build_invoke_body( + messages, conversation_id, forwarded_extras + ) + body_bytes = json.dumps(body).encode("utf-8") + _, signed_headers, _ = self.auth( + url=url, + method="POST", + headers=cfg.get_headers(), + body=body_bytes, + config=cfg, + ) + signed_headers.setdefault("Content-Type", "application/json") + logger.debug("super_agent invoke request: POST %s body=%s", url, body) + + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.post( + url, headers=signed_headers, content=body_bytes + ) + resp.raise_for_status() + payload = resp.json() + logger.debug( + "super_agent invoke response: status=%d payload=%s", + resp.status_code, + payload, + ) + return self._parse_invoke_response(payload) + + def invoke( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> InvokeResponseData: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def stream_async( + self, + stream_url: str, + stream_headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> AsyncIterator[SSEEvent]: + """Phase 2: GET stream_url → 流式 yield SSEEvent. + + 合并优先级 (低 → 高): ``cfg.get_headers()`` → ``stream_headers`` → RAM 签名头。 + """ + cfg = Config.with_configs(self.config, config) + _, signed_headers, _ = self.auth( + url=stream_url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + merged_headers: Dict[str, str] = {} + merged_headers.update(cfg.get_headers()) + if stream_headers: + merged_headers.update(stream_headers) + merged_headers.update(signed_headers) + + logger.debug("super_agent stream request: GET %s", stream_url) + timeout = httpx.Timeout(cfg.get_timeout(), read=None) + client = httpx.AsyncClient(timeout=timeout) + try: + ctx = client.stream("GET", stream_url, headers=merged_headers) + response = await ctx.__aenter__() + try: + response.raise_for_status() + logger.debug( + "super_agent stream response: status=%d headers=%s", + response.status_code, + dict(response.headers), + ) + async for event in parse_sse_async(response): + logger.debug("super_agent stream event: %s", event) + yield event + finally: + await ctx.__aexit__(None, None, None) + finally: + await client.aclose() + + def stream( + self, + stream_url: str, + stream_headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> Iterator[SSEEvent]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def list_conversations_async( + self, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[Dict[str, Any]]: + """GET /conversations → 返回服务端 ``data.conversations`` 数组 (缺失时返回 [])。 + + ``metadata`` 若非 None, 会以 JSON 编码后通过 ``metadata`` query 参数下发; + 服务端按该 metadata 过滤 (例如 ``{"agentRuntimeName": "..."}``)。 + 不传则由服务端按当前 sub uid 过滤。 + """ + cfg = Config.with_configs(self.config, config) + query: Optional[Dict[str, Any]] = None + if metadata is not None: + query = {"metadata": json.dumps(metadata, ensure_ascii=False)} + url = self.with_path("conversations", query=query, config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent list_conversations request: GET %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.get(url, headers=signed_headers) + resp.raise_for_status() + payload = resp.json() if resp.text else {} + logger.debug( + "super_agent list_conversations response: status=%d payload=%s", + resp.status_code, + payload, + ) + if not isinstance(payload, dict): + return [] + data = payload.get("data") + if not isinstance(data, dict): + return [] + raw_list = data.get("conversations") + if not isinstance(raw_list, list): + return [] + return [item for item in raw_list if isinstance(item, dict)] + + def list_conversations( + self, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[Dict[str, Any]]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def get_conversation_async( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """GET /conversations/{id} → 返回服务端 ``data`` 字段 dict.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path(f"conversations/{conversation_id}", config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent get_conversation request: GET %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.get(url, headers=signed_headers) + resp.raise_for_status() + payload = resp.json() if resp.text else {} + logger.debug( + "super_agent get_conversation response: status=%d payload=%s", + resp.status_code, + payload, + ) + if not isinstance(payload, dict): + return {} + data = payload.get("data") + return data if isinstance(data, dict) else {} + + def get_conversation( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def delete_conversation_async( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> None: + """DELETE /conversations/{id}.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path(f"conversations/{conversation_id}", config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="DELETE", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent delete_conversation request: DELETE %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.delete(url, headers=signed_headers) + resp.raise_for_status() + logger.debug( + "super_agent delete_conversation response: status=%d body=%s", + resp.status_code, + resp.text, + ) + + def delete_conversation( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> None: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + +__all__ = [ + "API_VERSION", + "SUPER_AGENT_RESOURCE_PATH", + "SUPER_AGENT_NAMESPACE", + "SuperAgentDataAPI", +] diff --git a/agentrun/super_agent/api/__init__.py b/agentrun/super_agent/api/__init__.py new file mode 100644 index 0000000..28825ad --- /dev/null +++ b/agentrun/super_agent/api/__init__.py @@ -0,0 +1,5 @@ +"""Super Agent 内部 API 模块 / Super Agent internal API module""" + +from .data import SuperAgentDataAPI + +__all__ = ["SuperAgentDataAPI"] diff --git a/agentrun/super_agent/api/control.py b/agentrun/super_agent/api/control.py new file mode 100644 index 0000000..afaf665 --- /dev/null +++ b/agentrun/super_agent/api/control.py @@ -0,0 +1,480 @@ +"""Super Agent 控制面辅助函数 / Super Agent Control Plane Helpers + +本模块包含: +- 常量: API 版本号 / 协议类型 / 标签 / 资源路径 / RAM 数据域名列表 +- URL 工具: ``_add_ram_prefix_to_host`` / ``build_super_agent_endpoint`` +- AgentRuntime ↔ SuperAgent 的双向转换: + ``to_create_input`` / ``to_update_input`` / ``from_agent_runtime`` + / ``is_super_agent`` / ``parse_super_agent_config`` +- 为承载 ``externalEndpoint`` 的 Pydantic 与 Dara 层扩展类 + +不使用模板生成,保持单一来源,避免同步/异步重复维护。 +""" + +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional, TYPE_CHECKING +from urllib.parse import urlparse, urlunparse + +if TYPE_CHECKING: + from agentrun.super_agent.agent import SuperAgent + +from alibabacloud_agentrun20250910.models import ( + ProtocolConfiguration as _DaraProtocolConfiguration, +) +from pydantic import Field + +from agentrun.agent_runtime.model import ( + AgentRuntimeArtifact, + AgentRuntimeContainer, + AgentRuntimeCreateInput, + AgentRuntimeProtocolConfig, + AgentRuntimeUpdateInput, +) +from agentrun.agent_runtime.runtime import AgentRuntime +from agentrun.utils.config import Config +from agentrun.utils.model import NetworkConfig, NetworkMode + +# ─── 常量 ───────────────────────────────────────────── +API_VERSION = "2025-09-10" +SUPER_AGENT_PROTOCOL_TYPE = "SUPER_AGENT" +# ``SUPER_AGENT_TAG`` 标识下游 AgentRuntime 是超级 Agent, 用于 list 过滤。 +# 写入 ``systemTags`` 字段 (由服务端原生支持), create/update/list 的 system_tags 参数统一使用。 +SUPER_AGENT_TAG = "x-agentrun-super" +# ``EXTERNAL_TAG`` 标识下游 AgentRuntime 由外部 (SuperAgent) 托管调用, 不由 AgentRun 直接托管。 +# 保留常量以便外部消费者引用; 创建超级 Agent 时不再写入此 tag。 +EXTERNAL_TAG = "x-agentrun-external" +# 创建下游 AgentRuntime 时固定写入的 systemTags 列表: ``[SUPER_AGENT_TAG]`` (仅一个)。 +SUPER_AGENT_CREATE_TAGS = [SUPER_AGENT_TAG] +SUPER_AGENT_RESOURCE_PATH = "__SUPER_AGENT__" +SUPER_AGENT_INVOKE_PATH = "/invoke" +SUPER_AGENT_NAMESPACE = ( + f"{API_VERSION}/super-agents/{SUPER_AGENT_RESOURCE_PATH}" +) + +_RAM_DATA_DOMAINS = ("agentrun-data", "funagent-data-pre") + +# SUPER_AGENT 不跑用户 container/code, 但服务端强制要求 artifact/container_configuration 非空, +# 这里给一个占位镜像地址即可。region 取杭州仅为了格式合法, 服务端不会实际 pull。 +_PLACEHOLDER_IMAGE = ( + "registry.cn-hangzhou.aliyuncs.com/agentrun/super-agent-placeholder:v1" +) + + +# ─── URL 工具 ────────────────────────────────────────── + + +def _add_ram_prefix_to_host(base_url: str) -> str: + """给已知 data host 加 ``-ram`` 前缀. + + 仅当 host 第二段命中 :data:`_RAM_DATA_DOMAINS` 时改写为 + ``-ram.<其余>``, 其他情况原样返回。 + 与 :meth:`agentrun.utils.data_api.DataAPI._get_ram_data_endpoint` 同源。 + """ + parsed = urlparse(base_url) + if not parsed.netloc: + return base_url + if not any(f".{d}." in parsed.netloc for d in _RAM_DATA_DOMAINS): + return base_url + parts = parsed.netloc.split(".", 1) + if len(parts) != 2: + return base_url + ram_netloc = parts[0] + "-ram." + parts[1] + return urlunparse(( + parsed.scheme, + ram_netloc, + parsed.path or "", + parsed.params, + parsed.query, + parsed.fragment, + )) + + +def build_super_agent_endpoint(cfg: Optional[Config] = None) -> str: + """构造 ``protocolConfiguration.externalEndpoint`` 的存储值 (不带版本号). + + 基于 :meth:`Config.get_data_endpoint` + :func:`_add_ram_prefix_to_host` + + 追加 ``/super-agents/__SUPER_AGENT__``, 自动适配生产 / 预发 / 自定义网关。 + """ + cfg = Config.with_configs(cfg) + base = cfg.get_data_endpoint() + ram_base = _add_ram_prefix_to_host(base) + return f"{ram_base.rstrip('/')}/super-agents/{SUPER_AGENT_RESOURCE_PATH}" + + +# ─── Pydantic 扩展类 ──────────────────────────────────── +class SuperAgentProtocolConfig(AgentRuntimeProtocolConfig): + """承载 ``protocol_settings`` + ``external_endpoint`` 的 Pydantic 扩展. + + 基类 ``AgentRuntimeProtocolConfig`` 的 ``type`` 字段是 ``HTTP / MCP`` 枚举, + 本子类通过 ``model_construct`` 绕过校验存入字符串 ``"SUPER_AGENT"``。 + """ + + protocol_settings: Optional[List[Dict[str, Any]]] = Field( + alias="protocolSettings", default=None + ) + external_endpoint: Optional[str] = Field( + alias="externalEndpoint", default=None + ) + + +class _SuperAgentCreateInput(AgentRuntimeCreateInput): + """默认使用 ``serialize_as_any=True`` 的 create input, 保留子类 extras. + + ``external_agent_endpoint_url`` 是基类 ``AgentRuntimeMutableProps`` 没有覆盖 + 的顶层字段, 这里显式补齐 (alias 由 BaseModel 的 ``to_camel_case`` 生成 → + ``externalAgentEndpointUrl``), 用于承载超级 Agent 的数据面入口地址。 + """ + + external_agent_endpoint_url: Optional[str] = None + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + kwargs.setdefault("serialize_as_any", True) + return super().model_dump(**kwargs) + + +class _SuperAgentUpdateInput(AgentRuntimeUpdateInput): + """默认使用 ``serialize_as_any=True`` 的 update input, 保留子类 extras.""" + + external_agent_endpoint_url: Optional[str] = None + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + kwargs.setdefault("serialize_as_any", True) + return super().model_dump(**kwargs) + + +# ─── Dara 模型猴补丁 ────────────────────────────────────── +# 当前版 Dara SDK 缺 ``ProtocolConfiguration.externalEndpoint`` 字段, 会在 +# Pydantic ↔ Dara ``from_map / to_map`` roundtrip 中静默丢失。补丁延迟到 +# ``SuperAgentClient`` 实例化时 (见 ``ensure_super_agent_patches_applied``) 才 +# 触发, 避免仅 import 本模块的调用方被动承担全局副作用。补丁用哨兵属性保证 +# 幂等, 重复调用安全。TODO: 等 Dara SDK 原生支持后删除。 +# +# ``tags`` (原 hack 写入) 已由 SDK 原生 ``systemTags`` 字段替代, 不再需要任何 +# 补丁; create/update/list 统一走 ``system_tags`` 参数。 + + +def _patch_dara_protocol_configuration() -> None: + """补齐 ``ProtocolConfiguration.externalEndpoint`` 的 from_map/to_map 读写.""" + if getattr(_DaraProtocolConfiguration, "_super_agent_patched", False): + return + + _orig_to_map = _DaraProtocolConfiguration.to_map + _orig_from_map = _DaraProtocolConfiguration.from_map + + def _patched_to_map(self: _DaraProtocolConfiguration) -> Dict[str, Any]: + result = _orig_to_map(self) + ee = getattr(self, "external_endpoint", None) + if ee is not None: + result["externalEndpoint"] = ee + return result + + def _patched_from_map( + self: _DaraProtocolConfiguration, m: Optional[Dict[str, Any]] = None + ) -> _DaraProtocolConfiguration: + _orig_from_map(self, m) + if m and m.get("externalEndpoint") is not None: + self.external_endpoint = m.get("externalEndpoint") + return self + + _DaraProtocolConfiguration.to_map = _patched_to_map # type: ignore[assignment] + _DaraProtocolConfiguration.from_map = _patched_from_map # type: ignore[assignment] + _DaraProtocolConfiguration._super_agent_patched = True # type: ignore[attr-defined] + + +def ensure_super_agent_patches_applied() -> None: + """按需应用 Dara SDK 兼容补丁 (幂等)。 + + 由 ``SuperAgentClient.__init__`` 调用。如果调用方直接使用 + ``to_create_input`` / ``to_update_input`` 并自己构造 Dara 输入, 也应在 + Pydantic → Dara 转换前调用一次本函数。 + """ + _patch_dara_protocol_configuration() + + +# ─── AgentRuntime ↔ SuperAgent 转换 ──────────────────────── +def _business_fields_from_args( + *, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, +) -> Dict[str, Any]: + """把业务字段 (None 保留为 None) 收拢成 dict, 供 ``protocolSettings.config`` 使用.""" + return { + "prompt": prompt, + "agents": agents if agents is not None else [], + "tools": tools if tools is not None else [], + "skills": skills if skills is not None else [], + "sandboxes": sandboxes if sandboxes is not None else [], + "workspaces": workspaces if workspaces is not None else [], + "modelServiceName": model_service_name, + "modelName": model_name, + } + + +def _build_protocol_settings_config( + *, name: str, business: Dict[str, Any] +) -> str: + """构造 ``protocolSettings[0].config`` 的 JSON 字符串. + + 新结构: 顶层 ``path`` / ``headers`` / ``body``, 业务字段收拢到 + ``body.forwardedProps`` (开放字典, 语义 "any, merge")。 + """ + forwarded_props: Dict[str, Any] = { + "prompt": business.get("prompt"), + "agents": business.get("agents") or [], + "tools": business.get("tools") or [], + "skills": business.get("skills") or [], + "sandboxes": business.get("sandboxes") or [], + "workspaces": business.get("workspaces") or [], + "modelServiceName": business.get("modelServiceName"), + "modelName": business.get("modelName"), + "metadata": {"agentRuntimeName": name}, + } + cfg_dict: Dict[str, Any] = { + "path": SUPER_AGENT_INVOKE_PATH, + "headers": {}, + "body": {"forwardedProps": forwarded_props}, + } + return json.dumps(cfg_dict, ensure_ascii=False) + + +def _build_protocol_configuration( + *, + name: str, + business: Dict[str, Any], + cfg: Optional[Config], +) -> SuperAgentProtocolConfig: + """构造超级 Agent 的 ``protocolConfiguration`` Pydantic 模型.""" + config_json = _build_protocol_settings_config(name=name, business=business) + settings: List[Dict[str, Any]] = [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "name": name, + "path": SUPER_AGENT_INVOKE_PATH, + "config": config_json, + }] + pc = SuperAgentProtocolConfig.model_construct( + type=SUPER_AGENT_PROTOCOL_TYPE, + protocol_settings=settings, + external_endpoint=build_super_agent_endpoint(cfg), + ) + return pc + + +def to_create_input( + name: str, + *, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + cfg: Optional[Config] = None, +) -> AgentRuntimeCreateInput: + """把超级 Agent 业务字段转为 :class:`AgentRuntimeCreateInput`.""" + business = _business_fields_from_args( + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + ) + pc = _build_protocol_configuration(name=name, business=business, cfg=cfg) + # SUPER_AGENT 是平台托管运行时, 不跑用户代码/容器, 但服务端仍要求 + # artifact_type / network_configuration 非空. 这里给占位默认值即可. + return _SuperAgentCreateInput.model_construct( + agent_runtime_name=name, + description=description, + protocol_configuration=pc, + system_tags=list(SUPER_AGENT_CREATE_TAGS), + # 超级 Agent 的数据面入口 (与 protocolConfiguration.externalEndpoint 同值)。 + external_agent_endpoint_url=build_super_agent_endpoint(cfg), + # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 + artifact_type=AgentRuntimeArtifact.CONTAINER, + container_configuration=AgentRuntimeContainer(image=_PLACEHOLDER_IMAGE), + network_configuration=NetworkConfig(network_mode=NetworkMode.PUBLIC), + ) + + +def to_update_input( + name: str, + merged: Dict[str, Any], + cfg: Optional[Config] = None, +) -> AgentRuntimeUpdateInput: + """把合并后的业务字段转为 :class:`AgentRuntimeUpdateInput` (全量替换).""" + business = _business_fields_from_args( + prompt=merged.get("prompt"), + agents=merged.get("agents"), + tools=merged.get("tools"), + skills=merged.get("skills"), + sandboxes=merged.get("sandboxes"), + workspaces=merged.get("workspaces"), + model_service_name=merged.get("model_service_name"), + model_name=merged.get("model_name"), + ) + pc = _build_protocol_configuration(name=name, business=business, cfg=cfg) + return _SuperAgentUpdateInput.model_construct( + agent_runtime_name=name, + description=merged.get("description"), + protocol_configuration=pc, + system_tags=list(SUPER_AGENT_CREATE_TAGS), + # 超级 Agent 的数据面入口 (与 protocolConfiguration.externalEndpoint 同值)。 + external_agent_endpoint_url=build_super_agent_endpoint(cfg), + # 占位 artifact: SUPER_AGENT 不跑用户 container/code, 但服务端要求非空。 + artifact_type=AgentRuntimeArtifact.CONTAINER, + container_configuration=AgentRuntimeContainer(image=_PLACEHOLDER_IMAGE), + network_configuration=NetworkConfig(network_mode=NetworkMode.PUBLIC), + ) + + +def _extract_protocol_configuration(rt: AgentRuntime) -> Dict[str, Any]: + """统一把 rt.protocol_configuration 转为 dict (兼容 dict / pydantic 两种形态).""" + pc = getattr(rt, "protocol_configuration", None) + if pc is None: + return {} + if isinstance(pc, dict): + return pc + # Pydantic 模型: 用 model_dump(serialize_as_any=True) 保留 extras + try: + return pc.model_dump(by_alias=True, serialize_as_any=True) + except TypeError: + return pc.model_dump(by_alias=True) + + +def _extract_protocol_settings(pc_dict: Dict[str, Any]) -> List[Dict[str, Any]]: + """从 protocolConfiguration dict 中取出 protocolSettings 列表.""" + for key in ("protocolSettings", "protocol_settings"): + v = pc_dict.get(key) + if isinstance(v, list): + return v + return [] + + +def is_super_agent(rt: AgentRuntime) -> bool: + """判断一个 AgentRuntime 是否为超级 Agent.""" + pc_dict = _extract_protocol_configuration(rt) + if not pc_dict: + return False + settings = _extract_protocol_settings(pc_dict) + if not settings: + return False + first = settings[0] if isinstance(settings[0], dict) else {} + return first.get("type") == SUPER_AGENT_PROTOCOL_TYPE + + +def _flatten_protocol_config(cfg: Any) -> Dict[str, Any]: + """把 ``protocolSettings[0].config`` 解析结果压平为扁平业务字段 dict. + + 兼容两种物理布局: + - 新: ``{"path": ..., "headers": ..., "body": {"forwardedProps": {...}}}`` + - 旧: 业务字段直接在根 (历史 AgentRuntime, 迁移前写入) + + 两种结构都返回扁平的业务字段 dict, 上游 + :func:`from_agent_runtime` 无需感知物理布局差异。 + """ + if not isinstance(cfg, dict): + return {} + body = cfg.get("body") + if isinstance(body, dict): + forwarded = body.get("forwardedProps") + if isinstance(forwarded, dict): + return forwarded + return cfg + + +def parse_super_agent_config(rt: AgentRuntime) -> Dict[str, Any]: + """反解 ``protocolSettings[0].config`` 为扁平业务字段 dict. + + 如果 config 缺失或非法 JSON, 返回空 dict (不抛异常)。 + 新旧嵌套布局由 :func:`_flatten_protocol_config` 统一拍平。 + """ + pc_dict = _extract_protocol_configuration(rt) + if not pc_dict: + return {} + settings = _extract_protocol_settings(pc_dict) + if not settings: + return {} + first = settings[0] if isinstance(settings[0], dict) else {} + raw_config = first.get("config") + if not raw_config: + return {} + if isinstance(raw_config, dict): + return _flatten_protocol_config(raw_config) + if isinstance(raw_config, str): + try: + parsed = json.loads(raw_config) + except (TypeError, ValueError): + return {} + return _flatten_protocol_config(parsed) + return {} + + +def _get_external_endpoint(rt: AgentRuntime) -> str: + pc_dict = _extract_protocol_configuration(rt) + return ( + pc_dict.get("externalEndpoint") + or pc_dict.get("external_endpoint", "") + or "" + ) + + +def from_agent_runtime(rt: AgentRuntime) -> "SuperAgent": + """反解 AgentRuntime → SuperAgent 实例 (不注入 ``_client``).""" + # 延迟导入避免循环 + from agentrun.super_agent.agent import SuperAgent + + business = parse_super_agent_config(rt) + return SuperAgent( + name=getattr(rt, "agent_runtime_name", None) or "", + description=getattr(rt, "description", None), + prompt=business.get("prompt"), + agents=list(business.get("agents") or []), + tools=list(business.get("tools") or []), + skills=list(business.get("skills") or []), + sandboxes=list(business.get("sandboxes") or []), + workspaces=list(business.get("workspaces") or []), + model_service_name=business.get("modelServiceName"), + model_name=business.get("modelName"), + agent_runtime_id=getattr(rt, "agent_runtime_id", None) or "", + arn=getattr(rt, "agent_runtime_arn", None) or "", + status=str(getattr(rt, "status", "") or ""), + created_at=getattr(rt, "created_at", None) or "", + last_updated_at=getattr(rt, "last_updated_at", None) or "", + external_endpoint=_get_external_endpoint(rt), + ) + + +__all__ = [ + "API_VERSION", + "SUPER_AGENT_PROTOCOL_TYPE", + "SUPER_AGENT_TAG", + "EXTERNAL_TAG", + "SUPER_AGENT_CREATE_TAGS", + "SUPER_AGENT_RESOURCE_PATH", + "SUPER_AGENT_INVOKE_PATH", + "SUPER_AGENT_NAMESPACE", + "SuperAgentProtocolConfig", + "_SuperAgentCreateInput", + "_SuperAgentUpdateInput", + "build_super_agent_endpoint", + "_add_ram_prefix_to_host", + "to_create_input", + "to_update_input", + "from_agent_runtime", + "is_super_agent", + "parse_super_agent_config", + "ensure_super_agent_patches_applied", +] diff --git a/agentrun/super_agent/api/data.py b/agentrun/super_agent/api/data.py new file mode 100644 index 0000000..947e0b4 --- /dev/null +++ b/agentrun/super_agent/api/data.py @@ -0,0 +1,324 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/super_agent/api/__data_async_template.py + +Super Agent 数据面 API / Super Agent Data Plane API + +继承 :class:`agentrun.utils.data_api.DataAPI`, 复用: + +- ``_get_ram_data_endpoint``: RAM ``-ram`` 前缀改写 +- ``get_base_url`` / ``with_path``: URL 拼接 (含版本号) +- ``auth``: RAM 签名头生成 (``Agentrun-Authorization`` / ``x-acs-*``) + +同步方法第一版抛 ``NotImplementedError``, 仅保留骨架以备未来扩展。 +""" + +import json +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional + +import httpx + +from agentrun.super_agent.model import InvokeResponseData +from agentrun.super_agent.stream import parse_sse_async, SSEEvent +from agentrun.utils.config import Config +from agentrun.utils.data_api import DataAPI, ResourceType +from agentrun.utils.log import logger + +API_VERSION = "2025-09-10" +SUPER_AGENT_RESOURCE_PATH = "__SUPER_AGENT__" +SUPER_AGENT_NAMESPACE = ( + f"{API_VERSION}/super-agents/{SUPER_AGENT_RESOURCE_PATH}" +) + +_SYNC_UNSUPPORTED_MSG = ( + "sync version not supported, use *_async (see decision 14 in" + " openspec/changes/add-super-agent-sdk/design.md)" +) + + +class SuperAgentDataAPI(DataAPI): + """Super Agent 数据面 API (异步主路径 + 同步占位).""" + + def __init__( + self, + agent_runtime_name: str, + config: Optional[Config] = None, + ): + super().__init__( + resource_name=agent_runtime_name, + resource_type=ResourceType.Runtime, + namespace=SUPER_AGENT_NAMESPACE, + config=config, + ) + self.agent_runtime_name = agent_runtime_name + + def _build_invoke_body( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str], + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + # ``forwarded_extras`` 承载从 AgentRuntime 元数据读出的业务字段 + # (prompt/agents/tools/skills/sandboxes/workspaces/modelServiceName/modelName), + # 由上层 ``SuperAgent.invoke_async`` 注入。``metadata`` 和 ``conversationId`` + # 由 SDK 管理, 不允许 extras 覆盖。 + forwarded: Dict[str, Any] = dict(forwarded_extras or {}) + forwarded["metadata"] = {"agentRuntimeName": self.agent_runtime_name} + if conversation_id is not None: + forwarded["conversationId"] = conversation_id + return {"messages": list(messages), "forwardedProps": forwarded} + + def _parse_invoke_response( + self, payload: Dict[str, Any] + ) -> InvokeResponseData: + if not isinstance(payload, dict): + raise ValueError( + "Invalid invoke response: expected object, got" + f" {type(payload).__name__}" + ) + data = payload.get("data") + if not isinstance(data, dict): + raise ValueError("Invalid invoke response: missing data field") + conversation_id = data.get("conversationId") + if not conversation_id: + raise ValueError( + "Invalid invoke response: missing conversationId field" + ) + stream_url = data.get("url") + if not stream_url: + raise ValueError("Invalid invoke response: missing url field") + raw_headers = data.get("headers") or {} + stream_headers = {str(k): str(v) for k, v in dict(raw_headers).items()} + return InvokeResponseData( + conversation_id=str(conversation_id), + stream_url=str(stream_url), + stream_headers=stream_headers, + ) + + async def invoke_async( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> InvokeResponseData: + """Phase 1: POST /invoke, 返回 Phase 2 的 URL 与 headers.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path("invoke", config=cfg) + body = self._build_invoke_body( + messages, conversation_id, forwarded_extras + ) + body_bytes = json.dumps(body).encode("utf-8") + _, signed_headers, _ = self.auth( + url=url, + method="POST", + headers=cfg.get_headers(), + body=body_bytes, + config=cfg, + ) + signed_headers.setdefault("Content-Type", "application/json") + logger.debug("super_agent invoke request: POST %s body=%s", url, body) + + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.post( + url, headers=signed_headers, content=body_bytes + ) + resp.raise_for_status() + payload = resp.json() + logger.debug( + "super_agent invoke response: status=%d payload=%s", + resp.status_code, + payload, + ) + return self._parse_invoke_response(payload) + + def invoke( + self, + messages: List[Dict[str, Any]], + conversation_id: Optional[str] = None, + config: Optional[Config] = None, + forwarded_extras: Optional[Dict[str, Any]] = None, + ) -> InvokeResponseData: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def stream_async( + self, + stream_url: str, + stream_headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> AsyncIterator[SSEEvent]: + """Phase 2: GET stream_url → 流式 yield SSEEvent. + + 合并优先级 (低 → 高): ``cfg.get_headers()`` → ``stream_headers`` → RAM 签名头。 + """ + cfg = Config.with_configs(self.config, config) + _, signed_headers, _ = self.auth( + url=stream_url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + merged_headers: Dict[str, str] = {} + merged_headers.update(cfg.get_headers()) + if stream_headers: + merged_headers.update(stream_headers) + merged_headers.update(signed_headers) + + logger.debug("super_agent stream request: GET %s", stream_url) + timeout = httpx.Timeout(cfg.get_timeout(), read=None) + client = httpx.AsyncClient(timeout=timeout) + try: + ctx = client.stream("GET", stream_url, headers=merged_headers) + response = await ctx.__aenter__() + try: + response.raise_for_status() + logger.debug( + "super_agent stream response: status=%d headers=%s", + response.status_code, + dict(response.headers), + ) + async for event in parse_sse_async(response): + logger.debug("super_agent stream event: %s", event) + yield event + finally: + await ctx.__aexit__(None, None, None) + finally: + await client.aclose() + + def stream( + self, + stream_url: str, + stream_headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> Iterator[SSEEvent]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def list_conversations_async( + self, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[Dict[str, Any]]: + """GET /conversations → 返回服务端 ``data.conversations`` 数组 (缺失时返回 [])。 + + ``metadata`` 若非 None, 会以 JSON 编码后通过 ``metadata`` query 参数下发; + 服务端按该 metadata 过滤 (例如 ``{"agentRuntimeName": "..."}``)。 + 不传则由服务端按当前 sub uid 过滤。 + """ + cfg = Config.with_configs(self.config, config) + query: Optional[Dict[str, Any]] = None + if metadata is not None: + query = {"metadata": json.dumps(metadata, ensure_ascii=False)} + url = self.with_path("conversations", query=query, config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent list_conversations request: GET %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.get(url, headers=signed_headers) + resp.raise_for_status() + payload = resp.json() if resp.text else {} + logger.debug( + "super_agent list_conversations response: status=%d payload=%s", + resp.status_code, + payload, + ) + if not isinstance(payload, dict): + return [] + data = payload.get("data") + if not isinstance(data, dict): + return [] + raw_list = data.get("conversations") + if not isinstance(raw_list, list): + return [] + return [item for item in raw_list if isinstance(item, dict)] + + def list_conversations( + self, + metadata: Optional[Dict[str, Any]] = None, + config: Optional[Config] = None, + ) -> List[Dict[str, Any]]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def get_conversation_async( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """GET /conversations/{id} → 返回服务端 ``data`` 字段 dict.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path(f"conversations/{conversation_id}", config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="GET", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent get_conversation request: GET %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.get(url, headers=signed_headers) + resp.raise_for_status() + payload = resp.json() if resp.text else {} + logger.debug( + "super_agent get_conversation response: status=%d payload=%s", + resp.status_code, + payload, + ) + if not isinstance(payload, dict): + return {} + data = payload.get("data") + return data if isinstance(data, dict) else {} + + def get_conversation( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + async def delete_conversation_async( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> None: + """DELETE /conversations/{id}.""" + cfg = Config.with_configs(self.config, config) + url = self.with_path(f"conversations/{conversation_id}", config=cfg) + _, signed_headers, _ = self.auth( + url=url, + method="DELETE", + headers=cfg.get_headers(), + config=cfg, + ) + logger.debug("super_agent delete_conversation request: DELETE %s", url) + async with httpx.AsyncClient(timeout=cfg.get_timeout()) as client: + resp = await client.delete(url, headers=signed_headers) + resp.raise_for_status() + logger.debug( + "super_agent delete_conversation response: status=%d body=%s", + resp.status_code, + resp.text, + ) + + def delete_conversation( + self, + conversation_id: str, + config: Optional[Config] = None, + ) -> None: + raise NotImplementedError(_SYNC_UNSUPPORTED_MSG) + + +__all__ = [ + "API_VERSION", + "SUPER_AGENT_RESOURCE_PATH", + "SUPER_AGENT_NAMESPACE", + "SuperAgentDataAPI", +] diff --git a/agentrun/super_agent/client.py b/agentrun/super_agent/client.py new file mode 100644 index 0000000..2f0da6b --- /dev/null +++ b/agentrun/super_agent/client.py @@ -0,0 +1,531 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/super_agent/__client_async_template.py + +SuperAgentClient / 超级 Agent 客户端 + +对外入口: CRUDL (create / get / update / delete / list / list_all) 同步 + 异步双写。 +内部持有一个 :class:`AgentRuntimeClient` 实例, 通过 ``api/control.py`` 的 +转换函数把 ``SuperAgent`` 与 ``AgentRuntime`` 互相映射。 + +list 固定按 systemTag ``x-agentrun-super`` 过滤, 不接受用户自定义 tag。 +""" + +import asyncio +import time +from typing import Any, List, Optional + +from alibabacloud_agentrun20250910.models import ( + CreateAgentRuntimeInput, + UpdateAgentRuntimeInput, +) + +from agentrun.agent_runtime.api import AgentRuntimeControlAPI +from agentrun.agent_runtime.client import AgentRuntimeClient +from agentrun.agent_runtime.model import AgentRuntimeListInput +from agentrun.agent_runtime.runtime import AgentRuntime +from agentrun.super_agent.agent import SuperAgent +from agentrun.super_agent.api.control import ( + ensure_super_agent_patches_applied, + from_agent_runtime, + is_super_agent, + SUPER_AGENT_TAG, + to_create_input, + to_update_input, +) +from agentrun.utils.config import Config +from agentrun.utils.log import logger +from agentrun.utils.model import Status + +# 公开 API 签名故意保持 ``Optional[X] = None`` 对外简洁; +# ``_UNSET`` 仅用于内部区分 "未传" 与 "显式 None (= 清空)". +_UNSET: Any = object() + +# create/update 轮询默认参数 +_WAIT_INTERVAL_SECONDS = 3 +_WAIT_TIMEOUT_SECONDS = 300 + + +def _raise_if_failed(rt: AgentRuntime, action: str) -> None: + """若 rt 处于失败态, 抛出带 status_reason 的 RuntimeError.""" + status = getattr(rt, "status", None) + status_str = str(status) if status is not None else "" + if status_str in { + Status.CREATE_FAILED.value, + Status.UPDATE_FAILED.value, + Status.DELETE_FAILED.value, + }: + reason = getattr(rt, "status_reason", None) or "(no reason)" + name = getattr(rt, "agent_runtime_name", None) or "(unknown)" + raise RuntimeError( + f"Super agent {action} failed: name={name!r} status={status_str} " + f"reason={reason}" + ) + + +def _merge(current: dict, updates: dict) -> dict: + """把 ``updates`` 中非 ``_UNSET`` 的字段合并到 ``current`` (None 表示清空).""" + merged = dict(current) + for key, value in updates.items(): + if value is _UNSET: + continue + merged[key] = value + return merged + + +def _super_agent_to_business_dict(agent: SuperAgent) -> dict: + return { + "description": agent.description, + "prompt": agent.prompt, + "agents": list(agent.agents), + "tools": list(agent.tools), + "skills": list(agent.skills), + "sandboxes": list(agent.sandboxes), + "workspaces": list(agent.workspaces), + "model_service_name": agent.model_service_name, + "model_name": agent.model_name, + } + + +class SuperAgentClient: + """Super Agent CRUDL 客户端.""" + + def __init__(self, config: Optional[Config] = None) -> None: + # 按需打 Dara SDK 兼容补丁 (幂等)。放在本构造函数里, 让 "仅 import + # agentrun.super_agent" 的调用方不被动承担全局 SDK 副作用。 + ensure_super_agent_patches_applied() + self.config = config + self._rt = AgentRuntimeClient(config=config) + # create/update 绕过 AgentRuntimeClient 的 artifact_type 校验 (SUPER_AGENT 不需要 code/container), + # 并通过 ``ProtocolConfiguration`` 的 monkey-patch 保留 ``externalEndpoint`` 字段。 + self._rt_control = AgentRuntimeControlAPI(config=config) + + async def _wait_final_async( + self, + agent_runtime_id: str, + *, + config: Optional[Config] = None, + interval_seconds: int = _WAIT_INTERVAL_SECONDS, + timeout_seconds: int = _WAIT_TIMEOUT_SECONDS, + ) -> AgentRuntime: + """轮询 get 直到 status 进入最终态 (READY / *_FAILED).""" + cfg = Config.with_configs(self.config, config) + start = time.monotonic() + while True: + rt = await self._rt.get_async(agent_runtime_id, config=cfg) + status = getattr(rt, "status", None) + logger.debug( + "super agent %s poll status=%s", agent_runtime_id, status + ) + if Status.is_final_status(status): + return rt + if time.monotonic() - start > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for super agent {agent_runtime_id!r}" + f" to reach final status (last status={status})" + ) + await asyncio.sleep(interval_seconds) + + def _wait_final( + self, + agent_runtime_id: str, + *, + config: Optional[Config] = None, + interval_seconds: int = _WAIT_INTERVAL_SECONDS, + timeout_seconds: int = _WAIT_TIMEOUT_SECONDS, + ) -> AgentRuntime: + """同步版 _wait_final_async.""" + cfg = Config.with_configs(self.config, config) + start = time.monotonic() + while True: + rt = self._rt.get(agent_runtime_id, config=cfg) + status = getattr(rt, "status", None) + logger.debug( + "super agent %s poll status=%s", agent_runtime_id, status + ) + if Status.is_final_status(status): + return rt + if time.monotonic() - start > timeout_seconds: + raise TimeoutError( + f"Timed out waiting for super agent {agent_runtime_id!r}" + f" to reach final status (last status={status})" + ) + time.sleep(interval_seconds) + + # ─── Create ────────────────────────────────────── + async def create_async( + self, + *, + name: str, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> SuperAgent: + """异步创建超级 Agent.""" + cfg = Config.with_configs(self.config, config) + rt_input = to_create_input( + name, + description=description, + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + cfg=cfg, + ) + dara_input = CreateAgentRuntimeInput().from_map(rt_input.model_dump()) + result = await self._rt_control.create_agent_runtime_async( + dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + # 轮询直到进入最终态; 失败则抛出带 status_reason 的错误。 + agent_id = getattr(rt, "agent_runtime_id", None) + if agent_id and not Status.is_final_status(getattr(rt, "status", None)): + rt = await self._wait_final_async(agent_id, config=cfg) + _raise_if_failed(rt, action="create") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def create( + self, + *, + name: str, + description: Optional[str] = None, + prompt: Optional[str] = None, + agents: Optional[List[str]] = None, + tools: Optional[List[str]] = None, + skills: Optional[List[str]] = None, + sandboxes: Optional[List[str]] = None, + workspaces: Optional[List[str]] = None, + model_service_name: Optional[str] = None, + model_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> SuperAgent: + """同步创建超级 Agent.""" + cfg = Config.with_configs(self.config, config) + rt_input = to_create_input( + name, + description=description, + prompt=prompt, + agents=agents, + tools=tools, + skills=skills, + sandboxes=sandboxes, + workspaces=workspaces, + model_service_name=model_service_name, + model_name=model_name, + cfg=cfg, + ) + dara_input = CreateAgentRuntimeInput().from_map(rt_input.model_dump()) + result = self._rt_control.create_agent_runtime(dara_input, config=cfg) + rt = AgentRuntime.from_inner_object(result) + agent_id = getattr(rt, "agent_runtime_id", None) + if agent_id: + rt = self._wait_final(agent_id, config=cfg) + _raise_if_failed(rt, action="create") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Get ────────────────────────────────────────── + # Aliyun 控制面 get/delete/update 接口只认 ``agent_runtime_id`` (URN), + # 不认 resource_name; ``_find_rt_by_name*`` 用 list + 名称匹配来解析 id. + def _find_rt_by_name(self, name: str, config: Optional[Config]) -> Any: + cfg = Config.with_configs(self.config, config) + page_number = 1 + page_size = 50 + while True: + runtimes = self._rt.list( + AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + system_tags=SUPER_AGENT_TAG, + ), + config=cfg, + ) + for rt in runtimes: + if getattr(rt, "agent_runtime_name", None) == name: + return rt + if len(runtimes) < page_size: + return None + page_number += 1 + + async def _find_rt_by_name_async( + self, name: str, config: Optional[Config] + ) -> Any: + cfg = Config.with_configs(self.config, config) + page_number = 1 + page_size = 50 + while True: + runtimes = await self._rt.list_async( + AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + system_tags=SUPER_AGENT_TAG, + ), + config=cfg, + ) + for rt in runtimes: + if getattr(rt, "agent_runtime_name", None) == name: + return rt + if len(runtimes) < page_size: + return None + page_number += 1 + + async def get_async( + self, name: str, *, config: Optional[Config] = None + ) -> SuperAgent: + """异步获取超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def get(self, name: str, *, config: Optional[Config] = None) -> SuperAgent: + """同步获取超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Update (read-merge-write) ───────────────────── + # 参数默认值 ``_UNSET`` 是内部哨兵 (object())。为保留 IDE 自动补全与 mypy + # 类型检查, 签名保持精确类型标注, 对 ``= _UNSET`` 的赋值加 ``type: ignore``。 + # 未传 = 保持不变, 显式传 None = 清空字段。 + async def update_async( + self, + name: str, + *, + description: Optional[str] = _UNSET, # type: ignore[assignment] + prompt: Optional[str] = _UNSET, # type: ignore[assignment] + agents: Optional[List[str]] = _UNSET, # type: ignore[assignment] + tools: Optional[List[str]] = _UNSET, # type: ignore[assignment] + skills: Optional[List[str]] = _UNSET, # type: ignore[assignment] + sandboxes: Optional[List[str]] = _UNSET, # type: ignore[assignment] + workspaces: Optional[List[str]] = _UNSET, # type: ignore[assignment] + model_service_name: Optional[str] = _UNSET, # type: ignore[assignment] + model_name: Optional[str] = _UNSET, # type: ignore[assignment] + config: Optional[Config] = None, + ) -> SuperAgent: + """异步更新超级 Agent (read-merge-write).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + current = _super_agent_to_business_dict(from_agent_runtime(rt)) + updates = { + "description": description, + "prompt": prompt, + "agents": agents, + "tools": tools, + "skills": skills, + "sandboxes": sandboxes, + "workspaces": workspaces, + "model_service_name": model_service_name, + "model_name": model_name, + } + merged = _merge(current, updates) + rt_input = to_update_input(name, merged, cfg=cfg) + dara_input = UpdateAgentRuntimeInput().from_map(rt_input.model_dump()) + agent_id = getattr(rt, "agent_runtime_id", None) or name + result = await self._rt_control.update_agent_runtime_async( + agent_id, dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + rt_id = getattr(rt, "agent_runtime_id", None) or agent_id + if rt_id: + rt = await self._wait_final_async(rt_id, config=cfg) + _raise_if_failed(rt, action="update") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + def update( + self, + name: str, + *, + description: Optional[str] = _UNSET, # type: ignore[assignment] + prompt: Optional[str] = _UNSET, # type: ignore[assignment] + agents: Optional[List[str]] = _UNSET, # type: ignore[assignment] + tools: Optional[List[str]] = _UNSET, # type: ignore[assignment] + skills: Optional[List[str]] = _UNSET, # type: ignore[assignment] + sandboxes: Optional[List[str]] = _UNSET, # type: ignore[assignment] + workspaces: Optional[List[str]] = _UNSET, # type: ignore[assignment] + model_service_name: Optional[str] = _UNSET, # type: ignore[assignment] + model_name: Optional[str] = _UNSET, # type: ignore[assignment] + config: Optional[Config] = None, + ) -> SuperAgent: + """同步更新超级 Agent (read-merge-write).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + if not is_super_agent(rt): + raise ValueError(f"Resource {name!r} is not a super agent") + current = _super_agent_to_business_dict(from_agent_runtime(rt)) + updates = { + "description": description, + "prompt": prompt, + "agents": agents, + "tools": tools, + "skills": skills, + "sandboxes": sandboxes, + "workspaces": workspaces, + "model_service_name": model_service_name, + "model_name": model_name, + } + merged = _merge(current, updates) + rt_input = to_update_input(name, merged, cfg=cfg) + dara_input = UpdateAgentRuntimeInput().from_map(rt_input.model_dump()) + agent_id = getattr(rt, "agent_runtime_id", None) or name + result = self._rt_control.update_agent_runtime( + agent_id, dara_input, config=cfg + ) + rt = AgentRuntime.from_inner_object(result) + rt_id = getattr(rt, "agent_runtime_id", None) or agent_id + if rt_id: + rt = self._wait_final(rt_id, config=cfg) + _raise_if_failed(rt, action="update") + agent = from_agent_runtime(rt) + agent._client = self + return agent + + # ─── Delete ─────────────────────────────────────── + async def delete_async( + self, name: str, *, config: Optional[Config] = None + ) -> None: + """异步删除超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = await self._find_rt_by_name_async(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + agent_id = getattr(rt, "agent_runtime_id", None) or name + await self._rt.delete_async(agent_id, config=cfg) + + def delete(self, name: str, *, config: Optional[Config] = None) -> None: + """同步删除超级 Agent (名称解析 → ID).""" + cfg = Config.with_configs(self.config, config) + rt = self._find_rt_by_name(name, config=cfg) + if rt is None: + raise ValueError(f"Super agent {name!r} not found") + agent_id = getattr(rt, "agent_runtime_id", None) or name + self._rt.delete(agent_id, config=cfg) + + # ─── List ───────────────────────────────────────── + async def list_async( + self, + *, + page_number: int = 1, + page_size: int = 20, + config: Optional[Config] = None, + ) -> List[SuperAgent]: + """异步列出超级 Agent (固定 systemTag 过滤, 过滤非 SUPER_AGENT).""" + cfg = Config.with_configs(self.config, config) + rt_input = AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + system_tags=SUPER_AGENT_TAG, + ) + runtimes = await self._rt.list_async(rt_input, config=cfg) + result: List[SuperAgent] = [] + for rt in runtimes: + if not is_super_agent(rt): + continue + agent = from_agent_runtime(rt) + agent._client = self + result.append(agent) + return result + + def list( + self, + *, + page_number: int = 1, + page_size: int = 20, + config: Optional[Config] = None, + ) -> List[SuperAgent]: + """同步列出超级 Agent (固定 systemTag 过滤, 过滤非 SUPER_AGENT).""" + cfg = Config.with_configs(self.config, config) + rt_input = AgentRuntimeListInput( + page_number=page_number, + page_size=page_size, + system_tags=SUPER_AGENT_TAG, + ) + runtimes = self._rt.list(rt_input, config=cfg) + result: List[SuperAgent] = [] + for rt in runtimes: + if not is_super_agent(rt): + continue + agent = from_agent_runtime(rt) + agent._client = self + result.append(agent) + return result + + async def list_all_async( + self, *, config: Optional[Config] = None, page_size: int = 50 + ) -> List[SuperAgent]: + """异步一次性拉取所有超级 Agent (自动分页).""" + cfg = Config.with_configs(self.config, config) + result: List[SuperAgent] = [] + page_number = 1 + while True: + page = await self.list_async( + page_number=page_number, page_size=page_size, config=cfg + ) + if not page: + break + result.extend(page) + if len(page) < page_size: + break + page_number += 1 + return result + + def list_all( + self, *, config: Optional[Config] = None, page_size: int = 50 + ) -> List[SuperAgent]: + """同步一次性拉取所有超级 Agent (自动分页).""" + cfg = Config.with_configs(self.config, config) + result: List[SuperAgent] = [] + page_number = 1 + while True: + page = self.list( + page_number=page_number, page_size=page_size, config=cfg + ) + if not page: + break + result.extend(page) + if len(page) < page_size: + break + page_number += 1 + return result + + +__all__ = ["SuperAgentClient"] diff --git a/agentrun/super_agent/model.py b/agentrun/super_agent/model.py new file mode 100644 index 0000000..c710263 --- /dev/null +++ b/agentrun/super_agent/model.py @@ -0,0 +1,91 @@ +"""Super Agent 数据模型 / Super Agent Data Models + +此模块定义超级 Agent SDK 的输入输出数据模型。 +This module defines the input/output data models for the Super Agent SDK. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class SuperAgentCreateInput: + """超级 Agent 创建输入 / Super Agent creation input""" + + name: str + description: Optional[str] = None + prompt: Optional[str] = None + agents: List[str] = field(default_factory=list) + tools: List[str] = field(default_factory=list) + skills: List[str] = field(default_factory=list) + sandboxes: List[str] = field(default_factory=list) + workspaces: List[str] = field(default_factory=list) + model_service_name: Optional[str] = None + model_name: Optional[str] = None + + +@dataclass +class SuperAgentUpdateInput: + """超级 Agent 更新输入 / Super Agent update input + + 仅传想修改的字段, 其他保持不变。 + Only pass the fields to modify; others remain unchanged. + """ + + name: str + description: Optional[str] = None + prompt: Optional[str] = None + agents: Optional[List[str]] = None + tools: Optional[List[str]] = None + skills: Optional[List[str]] = None + sandboxes: Optional[List[str]] = None + workspaces: Optional[List[str]] = None + model_service_name: Optional[str] = None + model_name: Optional[str] = None + + +@dataclass +class SuperAgentListInput: + """超级 Agent 列表查询输入 / Super Agent list query input""" + + page_number: int = 1 + page_size: int = 20 + + +@dataclass +class InvokeResponseData: + """Phase 1 响应 data 字段的强类型表示。 + + Strongly-typed representation of the data field in the phase 1 response. + """ + + conversation_id: str + stream_url: str + stream_headers: Dict[str, str] + + +@dataclass +class Message: + """会话消息 / Conversation message""" + + role: str + content: str + message_id: Optional[str] = None + created_at: Optional[int] = None + + +@dataclass +class ConversationInfo: + """服务端会话信息 / Server-side conversation info""" + + conversation_id: str + agent_id: str + title: Optional[str] = None + main_user_id: Optional[str] = None + sub_user_id: Optional[str] = None + created_at: int = 0 + updated_at: int = 0 + error_message: Optional[str] = None + invoke_info: Optional[Dict[str, Any]] = None + messages: List[Message] = field(default_factory=list) + params: Optional[Dict[str, Any]] = None diff --git a/agentrun/super_agent/stream.py b/agentrun/super_agent/stream.py new file mode 100644 index 0000000..d1c76f5 --- /dev/null +++ b/agentrun/super_agent/stream.py @@ -0,0 +1,180 @@ +"""Super Agent SSE 流 / Super Agent SSE Stream + +- :class:`SSEEvent`: SSE 协议单个事件的原始表示, 不做业务 normalize。 +- :func:`parse_sse_async`: 从 ``httpx.Response.aiter_lines`` 提取 ``SSEEvent``。 +- :class:`InvokeStream`: Phase 1 状态载体 + Phase 2 懒触发的异步流。 + +故意不引入 ``httpx-sse`` 等额外依赖;约 30 行的解析器足以覆盖 +``event / data / id / retry`` 四字段、注释行、多行 ``data:``、流末尾 flush。 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +import json +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional + +import httpx + + +@dataclass +class SSEEvent: + """SSE 协议单个事件的原始表示. + + SDK 不做高阶 normalize, 调用方按需用 :meth:`data_json` 解析。 + """ + + event: Optional[str] = None + data: str = "" + id: Optional[str] = None + retry: Optional[int] = None + + def data_json(self) -> Optional[Any]: + """尝试把 ``data`` 解析为 JSON, 失败或为空返回 ``None``.""" + if not self.data: + return None + try: + return json.loads(self.data) + except (TypeError, ValueError): + return None + + +async def parse_sse_async( + response: httpx.Response, +) -> AsyncIterator[SSEEvent]: + """按行解析 SSE, 逐个 yield :class:`SSEEvent`. + + 规则: + + - 空行 = 事件边界, flush 当前字段 (允许空 event + 空 data 的空心跳事件被跳过) + - ``:`` 开头的行 = 注释, 忽略 + - ``field: value`` 形式, ``:`` 后第一个空格被去除 + - 多行 ``data:`` 用 ``\\n`` 拼接 + - ``retry`` 非整数时忽略 + - 未知字段忽略 (向前兼容) + - 流结束时若仍有未 flush 的字段, flush 一次 + """ + + event: Optional[str] = None + data_lines: List[str] = [] + sse_id: Optional[str] = None + retry: Optional[int] = None + + def _has_content() -> bool: + return bool(data_lines) or event is not None or sse_id is not None + + async for raw_line in response.aiter_lines(): + # httpx 的 aiter_lines 已去除换行符; 空字符串表示事件边界 + line = raw_line.rstrip("\r") + if line == "": + if _has_content(): + yield SSEEvent( + event=event, + data="\n".join(data_lines), + id=sse_id, + retry=retry, + ) + event = None + data_lines = [] + sse_id = None + retry = None + continue + + if line.startswith(":"): + continue + + if ":" in line: + field_name, _, value = line.partition(":") + if value.startswith(" "): + value = value[1:] + else: + field_name, value = line, "" + + if field_name == "event": + event = value + elif field_name == "data": + data_lines.append(value) + elif field_name == "id": + sse_id = value + elif field_name == "retry": + try: + retry = int(value) + except (TypeError, ValueError): + pass + # 未知字段忽略 + + if _has_content(): + yield SSEEvent( + event=event, + data="\n".join(data_lines), + id=sse_id, + retry=retry, + ) + + +StreamCallable = Callable[[], Awaitable[AsyncIterator[SSEEvent]]] +"""Phase 2 拉流回调: 返回 (awaitable of) 异步迭代器.""" + + +@dataclass +class InvokeStream: + """Phase 1 已完成的状态载体, 同时是 Phase 2 SSE 流的异步可迭代器. + + ``await SuperAgent.invoke_async(...)`` 完成后即可读: + - :attr:`conversation_id` + - :attr:`session_id` + - :attr:`stream_url` + - :attr:`stream_headers` + + 只在首次 ``async for`` 或 ``__aiter__`` 调用时才触发 Phase 2 GET。 + """ + + conversation_id: str + session_id: str + stream_url: str + stream_headers: Dict[str, str] + _stream_factory: StreamCallable + _iterator: Optional[AsyncIterator[SSEEvent]] = field( + default=None, init=False, repr=False + ) + _closed: bool = field(default=False, init=False, repr=False) + + async def _ensure_iterator(self) -> AsyncIterator[SSEEvent]: + if self._iterator is None: + self._iterator = await self._stream_factory() + return self._iterator + + def __aiter__(self) -> "InvokeStream": + return self + + async def __anext__(self) -> SSEEvent: + if self._closed: + raise StopAsyncIteration + iterator = await self._ensure_iterator() + try: + return await iterator.__anext__() + except StopAsyncIteration: + self._closed = True + raise + + async def aclose(self) -> None: + """提前关闭底层 HTTP 连接, 释放资源.""" + self._closed = True + iterator = self._iterator + self._iterator = None + if iterator is not None: + close = getattr(iterator, "aclose", None) + if close is not None: + try: + await close() + except Exception: # pragma: no cover - best effort cleanup + pass + + async def __aenter__(self) -> "InvokeStream": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self.aclose() + + +__all__ = ["SSEEvent", "parse_sse_async", "InvokeStream"] diff --git a/agentrun/utils/ram_signature/signer.py b/agentrun/utils/ram_signature/signer.py index c91ddf8..0f0bb70 100644 --- a/agentrun/utils/ram_signature/signer.py +++ b/agentrun/utils/ram_signature/signer.py @@ -12,7 +12,7 @@ import hashlib import hmac from typing import Optional -from urllib.parse import quote, unquote, urlparse +from urllib.parse import quote, unquote_plus, urlparse ALGORITHM = "AGENTRUN4-HMAC-SHA256" UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD" @@ -156,7 +156,7 @@ def get_agentrun_signed_headers( for pair in parsed.query.split("&"): if "=" in pair: k, v = pair.split("=", 1) - query_params[unquote(k)] = unquote(v) + query_params[unquote_plus(k)] = unquote_plus(v) now = sign_time if sign_time is not None else datetime.now(timezone.utc) if now.tzinfo is None: diff --git a/pyproject.toml b/pyproject.toml index 1c1b78a..251e0d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "litellm>=1.79.3", "alibabacloud-devs20230714>=2.4.1", "pydash>=8.0.5", - "alibabacloud-agentrun20250910>=5.6.0", + "alibabacloud-agentrun20250910>=5.6.1", "alibabacloud_tea_openapi>=0.4.2", "alibabacloud_bailian20231229>=2.6.2", "agentrun-mem0ai>=0.0.10", diff --git a/tests/unittests/integration/test_langchain_agui_integration.py b/tests/unittests/integration/test_langchain_agui_integration.py index 500c86c..ef0c076 100644 --- a/tests/unittests/integration/test_langchain_agui_integration.py +++ b/tests/unittests/integration/test_langchain_agui_integration.py @@ -689,9 +689,7 @@ async def invoke_agent(request: AgentRequest): json={ "messages": [{ "role": "user", - "content": ( - "查询当前的时间,并获取天气信息,同时输出我的密钥信息" - ), + "content": "查询当前的时间,并获取天气信息,同时输出我的密钥信息", }], "stream": True, }, @@ -757,9 +755,7 @@ async def invoke_agent(request: AgentRequest): json={ "messages": [{ "role": "user", - "content": ( - "查询当前的时间,并获取天气信息,同时输出我的密钥信息" - ), + "content": "查询当前的时间,并获取天气信息,同时输出我的密钥信息", }], "stream": True, }, diff --git a/tests/unittests/super_agent/__init__.py b/tests/unittests/super_agent/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/super_agent/test_agent.py b/tests/unittests/super_agent/test_agent.py new file mode 100644 index 0000000..b975fee --- /dev/null +++ b/tests/unittests/super_agent/test_agent.py @@ -0,0 +1,342 @@ +"""Unit tests for ``agentrun.super_agent.agent.SuperAgent``.""" + +import asyncio +import inspect +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.super_agent.agent import SuperAgent +from agentrun.super_agent.model import InvokeResponseData +from agentrun.super_agent.stream import InvokeStream + + +def _make_agent() -> SuperAgent: + return SuperAgent(name="demo") + + +def _mock_data_api(invoke_result: InvokeResponseData = None): + """Return a MagicMock replacing SuperAgentDataAPI constructor.""" + instance = MagicMock() + instance.invoke_async = AsyncMock(return_value=invoke_result) + instance.stream_async = MagicMock() + instance.get_conversation_async = AsyncMock() + instance.delete_conversation_async = AsyncMock() + factory = MagicMock(return_value=instance) + return factory, instance + + +# ─── invoke_async ─────────────────────────────────────────── + + +async def test_invoke_async_returns_invoke_stream_with_conversation_id(): + resp = InvokeResponseData( + conversation_id="c1", + stream_url="https://stream/", + stream_headers={"X-Super-Agent-Session-Id": "sess"}, + ) + factory, _ = _mock_data_api(resp) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + stream = await agent.invoke_async([{"role": "user", "content": "hi"}]) + assert stream.conversation_id == "c1" + assert stream.session_id == "sess" + assert stream.stream_url == "https://stream/" + assert stream.stream_headers == {"X-Super-Agent-Session-Id": "sess"} + + +async def test_invoke_async_no_tools_kwarg_raises(): + agent = _make_agent() + with pytest.raises(TypeError): + await agent.invoke_async([], tools=["t"]) # type: ignore[call-arg] + + +async def test_invoke_async_forwards_business_fields(): + """SuperAgent 实例字段 MUST 作为 forwarded_extras 透传给 DataAPI.""" + resp = InvokeResponseData( + conversation_id="c", + stream_url="https://x/", + stream_headers={}, + ) + invoke_mock = AsyncMock(return_value=resp) + instance = MagicMock() + instance.invoke_async = invoke_mock + factory = MagicMock(return_value=instance) + agent = SuperAgent( + name="demo", + prompt="p", + agents=["a1"], + tools=["t1", "t2"], + skills=["s1"], + sandboxes=["sb1"], + workspaces=["w1"], + model_service_name="svc", + model_name="mod", + ) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + await agent.invoke_async([{"role": "user", "content": "hi"}]) + + extras = invoke_mock.await_args.kwargs["forwarded_extras"] + assert extras == { + "prompt": "p", + "agents": ["a1"], + "tools": ["t1", "t2"], + "skills": ["s1"], + "sandboxes": ["sb1"], + "workspaces": ["w1"], + "modelServiceName": "svc", + "modelName": "mod", + } + + +async def test_invoke_async_forwards_business_fields_defaults(): + """没设置的 scalar 字段保留 None, list 字段为 [].""" + resp = InvokeResponseData( + conversation_id="c", stream_url="https://x/", stream_headers={} + ) + invoke_mock = AsyncMock(return_value=resp) + instance = MagicMock() + instance.invoke_async = invoke_mock + factory = MagicMock(return_value=instance) + agent = SuperAgent(name="demo") + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + await agent.invoke_async([]) + extras = invoke_mock.await_args.kwargs["forwarded_extras"] + assert extras == { + "prompt": None, + "agents": [], + "tools": [], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": None, + "modelName": None, + } + + +async def test_invoke_async_concurrent_streams_independent(): + responses = [ + InvokeResponseData( + conversation_id=f"c{i}", + stream_url=f"https://s{i}", + stream_headers={}, + ) + for i in range(2) + ] + invoke_mock = AsyncMock(side_effect=responses) + instance = MagicMock() + instance.invoke_async = invoke_mock + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + s0, s1 = await asyncio.gather( + agent.invoke_async([]), + agent.invoke_async([]), + ) + ids = {s0.conversation_id, s1.conversation_id} + assert ids == {"c0", "c1"} + + +async def test_invoke_async_phase2_lazy(): + resp = InvokeResponseData( + conversation_id="c", + stream_url="https://x/", + stream_headers={}, + ) + invoke_mock = AsyncMock(return_value=resp) + stream_mock = MagicMock() + instance = MagicMock() + instance.invoke_async = invoke_mock + instance.stream_async = stream_mock + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + stream = await agent.invoke_async([]) + # At this point Phase 2 must NOT have been called + stream_mock.assert_not_called() + # The stream factory stored inside InvokeStream should only invoke stream_async when iteration starts + assert isinstance(stream, InvokeStream) + + +# ─── get_conversation_async ───────────────────────────────── + + +async def test_get_conversation_async_returns_conversation_info(): + instance = MagicMock() + instance.get_conversation_async = AsyncMock( + return_value={ + "conversationId": "c1", + "agentId": "ag", + "title": "t", + "mainUserId": "u1", + "subUserId": "u2", + "createdAt": 100, + "updatedAt": 200, + "errorMessage": None, + "invokeInfo": {"foo": "bar"}, + "messages": [ + {"role": "user", "content": "hi", "messageId": "m1"}, + {"role": "assistant", "content": "hello"}, + ], + "params": {"a": 1}, + } + ) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + info = await _make_agent().get_conversation_async("c1") + assert info.conversation_id == "c1" + assert info.agent_id == "ag" + assert info.title == "t" + assert info.main_user_id == "u1" + assert info.sub_user_id == "u2" + assert info.created_at == 100 + assert info.updated_at == 200 + assert info.invoke_info == {"foo": "bar"} + assert len(info.messages) == 2 + assert info.messages[0].message_id == "m1" + assert info.params == {"a": 1} + + +async def test_get_conversation_async_partial_fields(): + instance = MagicMock() + instance.get_conversation_async = AsyncMock(return_value={"agentId": "x"}) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + info = await _make_agent().get_conversation_async("c1") + assert info.conversation_id == "c1" # fallback from argument + assert info.title is None + assert info.main_user_id is None + assert info.created_at == 0 + + +async def test_get_conversation_async_empty_messages(): + instance = MagicMock() + instance.get_conversation_async = AsyncMock(return_value={"messages": []}) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + info = await _make_agent().get_conversation_async("c1") + assert info.messages == [] + + +async def test_delete_conversation_async_returns_none(): + instance = MagicMock() + instance.delete_conversation_async = AsyncMock(return_value=None) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + assert await _make_agent().delete_conversation_async("c") is None + + +# ─── list_conversations_async ────────────────────────────── + + +async def test_list_conversations_async_default_filters_by_agent_name(): + """不传 metadata 时, 默认按 ``{"agentRuntimeName": self.name}`` 过滤.""" + instance = MagicMock() + instance.list_conversations_async = AsyncMock(return_value=[]) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = SuperAgent(name="my-agent") + result = await agent.list_conversations_async() + assert result == [] + assert instance.list_conversations_async.await_args.kwargs["metadata"] == { + "agentRuntimeName": "my-agent" + } + + +async def test_list_conversations_async_explicit_metadata_passthrough(): + instance = MagicMock() + instance.list_conversations_async = AsyncMock(return_value=[]) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + await agent.list_conversations_async(metadata={"foo": "bar"}) + assert instance.list_conversations_async.await_args.kwargs["metadata"] == { + "foo": "bar" + } + + +async def test_list_conversations_async_empty_metadata_overrides_default(): + """传入空 dict 明确表示「不按 agent 过滤」, SDK MUST 不再注入默认值.""" + instance = MagicMock() + instance.list_conversations_async = AsyncMock(return_value=[]) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + agent = _make_agent() + await agent.list_conversations_async(metadata={}) + assert instance.list_conversations_async.await_args.kwargs["metadata"] == {} + + +async def test_list_conversations_async_returns_conversation_info_list(): + instance = MagicMock() + instance.list_conversations_async = AsyncMock( + return_value=[ + { + "conversationId": "c1", + "agentId": "ag", + "title": "first", + "createdAt": 1, + "updatedAt": 2, + "messages": [{"role": "user", "content": "hi"}], + }, + { + "conversationId": "c2", + "agentId": "ag", + "title": "second", + "messages": [], + }, + ] + ) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + result = await _make_agent().list_conversations_async() + assert [c.conversation_id for c in result] == ["c1", "c2"] + assert result[0].title == "first" + assert len(result[0].messages) == 1 + assert result[0].messages[0].content == "hi" + assert result[1].title == "second" + + +async def test_list_conversations_async_empty_list(): + instance = MagicMock() + instance.list_conversations_async = AsyncMock(return_value=[]) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + assert await _make_agent().list_conversations_async() == [] + + +async def test_list_conversations_async_item_missing_conversation_id_uses_empty_fallback(): + """对单条会话里 ``conversationId`` 缺失的情况, fallback 保持空串 (不会用 agent name).""" + instance = MagicMock() + instance.list_conversations_async = AsyncMock( + return_value=[{"agentId": "ag"}] + ) + factory = MagicMock(return_value=instance) + with patch("agentrun.super_agent.agent.SuperAgentDataAPI", factory): + result = await _make_agent().list_conversations_async() + assert len(result) == 1 + assert result[0].conversation_id == "" + + +# ─── sync methods → NotImplementedError ───────────────────── + + +def test_sync_methods_not_implemented(): + agent = _make_agent() + with pytest.raises(NotImplementedError): + agent.invoke([]) + with pytest.raises(NotImplementedError): + agent.get_conversation("c") + with pytest.raises(NotImplementedError): + agent.delete_conversation("c") + with pytest.raises(NotImplementedError): + agent.list_conversations() + + +def test_invoke_async_signature_only_messages_and_conversation_id(): + sig = inspect.signature(SuperAgent.invoke_async) + params = list(sig.parameters.keys()) + # self, messages, then KEYWORD_ONLY: conversation_id, config + assert params[:2] == ["self", "messages"] + assert "conversation_id" in params + assert "config" in params + assert "tools" not in params diff --git a/tests/unittests/super_agent/test_agui.py b/tests/unittests/super_agent/test_agui.py new file mode 100644 index 0000000..d171f45 --- /dev/null +++ b/tests/unittests/super_agent/test_agui.py @@ -0,0 +1,498 @@ +"""Unit tests for ``agentrun.super_agent.agui``.""" + +import ast +import json +from pathlib import Path +from typing import List +from unittest.mock import AsyncMock, MagicMock + +from ag_ui.core import ( + BaseEvent, + CustomEvent, + MessagesSnapshotEvent, + RawEvent, + RunErrorEvent, + RunFinishedEvent, + RunStartedEvent, + StateDeltaEvent, + StateSnapshotEvent, + TextMessageContentEvent, + TextMessageEndEvent, + TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, +) +import pytest + +from agentrun.super_agent import agui as agui_mod +from agentrun.super_agent.agui import ( + _EVENT_TYPE_TO_CLASS, + as_agui_events, + decode_sse_to_agui, +) +from agentrun.super_agent.stream import InvokeStream, SSEEvent + + +def _make_stream_from_sse(events: List[SSEEvent]) -> InvokeStream: + async def _gen(): + for ev in events: + yield ev + + async def _factory(): + return _gen() + + return InvokeStream( + conversation_id="c", + session_id="s", + stream_url="https://x.com/s", + stream_headers={}, + _stream_factory=_factory, + ) + + +# ─── decode_sse_to_agui ────────────────────────────────────── + + +def test_decode_sse_text_message_content(): + payload = { + "type": "TEXT_MESSAGE_CONTENT", + "messageId": "m1", + "delta": "hi", + } + ev = SSEEvent(event="TEXT_MESSAGE_CONTENT", data=json.dumps(payload)) + result = decode_sse_to_agui(ev) + assert isinstance(result, TextMessageContentEvent) + + +def test_decode_sse_run_started(): + payload = {"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"} + ev = SSEEvent(event="RUN_STARTED", data=json.dumps(payload)) + assert isinstance(decode_sse_to_agui(ev), RunStartedEvent) + + +def test_decode_sse_run_finished(): + payload = {"type": "RUN_FINISHED", "threadId": "t1", "runId": "r1"} + ev = SSEEvent(event="RUN_FINISHED", data=json.dumps(payload)) + assert isinstance(decode_sse_to_agui(ev), RunFinishedEvent) + + +def test_decode_sse_run_error(): + payload = {"type": "RUN_ERROR", "message": "oops"} + ev = SSEEvent(event="RUN_ERROR", data=json.dumps(payload)) + assert isinstance(decode_sse_to_agui(ev), RunErrorEvent) + + +def test_decode_sse_text_message_lifecycle(): + cases = [ + ( + "TEXT_MESSAGE_START", + { + "type": "TEXT_MESSAGE_START", + "messageId": "m", + "role": "assistant", + }, + TextMessageStartEvent, + ), + ( + "TEXT_MESSAGE_CONTENT", + { + "type": "TEXT_MESSAGE_CONTENT", + "messageId": "m", + "delta": "x", + }, + TextMessageContentEvent, + ), + ( + "TEXT_MESSAGE_END", + { + "type": "TEXT_MESSAGE_END", + "messageId": "m", + }, + TextMessageEndEvent, + ), + ] + for name, payload, cls in cases: + result = decode_sse_to_agui( + SSEEvent(event=name, data=json.dumps(payload)) + ) + assert isinstance(result, cls), name + + +def test_decode_sse_tool_call_lifecycle(): + cases = [ + ( + "TOOL_CALL_START", + { + "type": "TOOL_CALL_START", + "toolCallId": "tc", + "toolCallName": "fn", + }, + ToolCallStartEvent, + ), + ( + "TOOL_CALL_ARGS", + { + "type": "TOOL_CALL_ARGS", + "toolCallId": "tc", + "delta": "arg", + }, + ToolCallArgsEvent, + ), + ( + "TOOL_CALL_END", + { + "type": "TOOL_CALL_END", + "toolCallId": "tc", + }, + ToolCallEndEvent, + ), + ( + "TOOL_CALL_RESULT", + { + "type": "TOOL_CALL_RESULT", + "toolCallId": "tc", + "messageId": "m", + "content": "r", + }, + ToolCallResultEvent, + ), + ] + for name, payload, cls in cases: + result = decode_sse_to_agui( + SSEEvent(event=name, data=json.dumps(payload)) + ) + assert isinstance(result, cls), name + + +def test_decode_sse_state_events(): + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="STATE_SNAPSHOT", + data=json.dumps({"type": "STATE_SNAPSHOT", "snapshot": {}}), + ) + ), + StateSnapshotEvent, + ) + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="STATE_DELTA", + data=json.dumps({"type": "STATE_DELTA", "delta": []}), + ) + ), + StateDeltaEvent, + ) + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="MESSAGES_SNAPSHOT", + data=json.dumps({ + "type": "MESSAGES_SNAPSHOT", + "messages": [], + }), + ) + ), + MessagesSnapshotEvent, + ) + + +def test_decode_sse_raw_and_custom(): + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="RAW", + data=json.dumps({"type": "RAW", "event": {"k": "v"}}), + ) + ), + RawEvent, + ) + assert isinstance( + decode_sse_to_agui( + SSEEvent( + event="CUSTOM", + data=json.dumps({"type": "CUSTOM", "name": "n", "value": 1}), + ) + ), + CustomEvent, + ) + + +def test_decode_sse_empty_data_returns_none(): + assert decode_sse_to_agui(SSEEvent(event=None, data="")) is None + assert decode_sse_to_agui(SSEEvent(event="RUN_STARTED", data="")) is None + + +def test_decode_sse_unknown_event_raise(): + with pytest.raises(ValueError) as exc: + decode_sse_to_agui(SSEEvent(event="UNKNOWN_X", data="{}")) + assert "UNKNOWN_X" in str(exc.value) + assert "{}" in str(exc.value) + + +def test_decode_sse_unknown_event_skip(): + result = decode_sse_to_agui( + SSEEvent(event="UNKNOWN_X", data="{}"), on_unknown="skip" + ) + assert result is None + + +def test_decode_sse_invalid_json_raises(): + with pytest.raises(ValueError) as exc: + decode_sse_to_agui( + SSEEvent(event="TEXT_MESSAGE_CONTENT", data="not json") + ) + assert "TEXT_MESSAGE_CONTENT" in str(exc.value) + assert "not json" in str(exc.value) + + +def test_decode_sse_pydantic_validation_failure_raises(): + with pytest.raises(ValueError): + decode_sse_to_agui( + SSEEvent( + event="TEXT_MESSAGE_CONTENT", + data='{"unrelated":"x"}', + ) + ) + + +def test_decode_sse_data_with_newlines(): + # Embedded escaped newlines are fine inside JSON + payload = json.dumps({ + "type": "TEXT_MESSAGE_CONTENT", + "messageId": "m", + "delta": "hi\nthere", + }) + result = decode_sse_to_agui( + SSEEvent(event="TEXT_MESSAGE_CONTENT", data=payload) + ) + assert isinstance(result, TextMessageContentEvent) + assert result.delta == "hi\nthere" + + +# ─── as_agui_events ────────────────────────────────────────── + + +async def test_as_agui_events_yields_typed_events(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent( + event="TEXT_MESSAGE_CONTENT", + data=json.dumps({ + "type": "TEXT_MESSAGE_CONTENT", + "messageId": "m", + "delta": "x", + }), + ), + SSEEvent( + event="RUN_FINISHED", + data=json.dumps({ + "type": "RUN_FINISHED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + collected: List[BaseEvent] = [ev async for ev in as_agui_events(stream)] + assert [type(e).__name__ for e in collected] == [ + "RunStartedEvent", + "TextMessageContentEvent", + "RunFinishedEvent", + ] + + +async def test_as_agui_events_skips_empty_data(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent(event="RUN_STARTED", data=""), # keepalive + SSEEvent( + event="RUN_FINISHED", + data=json.dumps({ + "type": "RUN_FINISHED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + collected = [ev async for ev in as_agui_events(stream)] + assert len(collected) == 2 + + +async def test_as_agui_events_unknown_skip_mode(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent(event="UNKNOWN_X", data="{}"), + SSEEvent( + event="RUN_FINISHED", + data=json.dumps({ + "type": "RUN_FINISHED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + collected = [ev async for ev in as_agui_events(stream, on_unknown="skip")] + assert len(collected) == 2 + + +async def test_as_agui_events_unknown_raise_mode(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent(event="UNKNOWN_X", data="{}"), + ] + stream = _make_stream_from_sse(events) + it = as_agui_events(stream) + first = await it.__anext__() + assert isinstance(first, RunStartedEvent) + with pytest.raises(ValueError): + await it.__anext__() + + +async def test_as_agui_events_closes_stream_on_normal_end(): + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + close = AsyncMock() + stream.aclose = close + async for _ in as_agui_events(stream): + pass + close.assert_awaited_once() + + +async def test_as_agui_events_closes_stream_on_consumer_exception(): + """当消费循环里抛异常, 只要消费者用 ``aclosing`` 包裹 (或手动 aclose), + 适配器的 ``finally`` 就能跑到 ``stream.aclose`` — 这是异步生成器清理的 + 标准用法 (Python async gen 的清理不会在异常透传时自动同步执行). + """ + from contextlib import aclosing + + events = [ + SSEEvent( + event="RUN_STARTED", + data=json.dumps({ + "type": "RUN_STARTED", + "threadId": "t", + "runId": "r", + }), + ), + SSEEvent( + event="RUN_FINISHED", + data=json.dumps({ + "type": "RUN_FINISHED", + "threadId": "t", + "runId": "r", + }), + ), + ] + stream = _make_stream_from_sse(events) + close = AsyncMock() + stream.aclose = close + with pytest.raises(RuntimeError): + async with aclosing(as_agui_events(stream)) as gen: + async for _ in gen: + raise RuntimeError("consumer blew up") + close.assert_awaited_once() + + +async def test_as_agui_events_closes_stream_on_decode_exception(): + events = [ + SSEEvent( + event="TEXT_MESSAGE_CONTENT", + data="not json", + ), + ] + stream = _make_stream_from_sse(events) + close = AsyncMock() + stream.aclose = close + with pytest.raises(ValueError): + async for _ in as_agui_events(stream): + pass + close.assert_awaited_once() + + +# ─── map completeness / module hygiene ────────────────────── + + +def test_event_type_to_class_map_completeness(): + required = { + "RUN_STARTED", + "RUN_FINISHED", + "RUN_ERROR", + "TEXT_MESSAGE_START", + "TEXT_MESSAGE_CONTENT", + "TEXT_MESSAGE_END", + "TOOL_CALL_START", + "TOOL_CALL_ARGS", + "TOOL_CALL_END", + "TOOL_CALL_RESULT", + "STATE_SNAPSHOT", + "STATE_DELTA", + "MESSAGES_SNAPSHOT", + "RAW", + "CUSTOM", + } + assert required <= set(_EVENT_TYPE_TO_CLASS.keys()) + + +def test_no_sync_as_agui_events_export(): + attrs = dir(agui_mod) + assert "as_agui_events_sync" not in attrs + assert "sync_as_agui_events" not in attrs + + +def test_agui_module_only_imports_ag_ui_core(): + src = Path(agui_mod.__file__).read_text() + tree = ast.parse(src) + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and node.module: + if node.module.startswith("ag_ui"): + assert ( + node.module == "ag_ui.core" + ), f"Disallowed ag-ui import: {node.module}" + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name.startswith("ag_ui"): + assert ( + alias.name == "ag_ui.core" + ), f"Disallowed ag-ui import: {alias.name}" diff --git a/tests/unittests/super_agent/test_client.py b/tests/unittests/super_agent/test_client.py new file mode 100644 index 0000000..5c1bc80 --- /dev/null +++ b/tests/unittests/super_agent/test_client.py @@ -0,0 +1,772 @@ +"""Unit tests for ``agentrun.super_agent.client.SuperAgentClient``.""" + +import asyncio +import inspect +import json +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.super_agent.api.control import ( + SUPER_AGENT_PROTOCOL_TYPE, + SUPER_AGENT_TAG, +) +from agentrun.super_agent.client import SuperAgentClient +from agentrun.utils.config import Config + + +def _client_config() -> Config: + return Config( + access_key_id="AK", + access_key_secret="SK", + account_id="123", + region_id="cn-hangzhou", + ) + + +class _DaraResult: + """Fake Dara-level AgentRuntime returned from the control API.""" + + def __init__(self, map_dict: dict): + self._map = map_dict + + def to_map(self): + return self._map + + +def _rt_to_dara_result(rt: SimpleNamespace) -> _DaraResult: + """Turn the internal ``_fake_rt`` into a Dara-style object with ``to_map``.""" + return _DaraResult({ + "agentRuntimeName": rt.agent_runtime_name, + "agentRuntimeId": rt.agent_runtime_id, + "agentRuntimeArn": rt.agent_runtime_arn, + "status": rt.status, + "createdAt": rt.created_at, + "lastUpdatedAt": rt.last_updated_at, + "description": rt.description, + "protocolConfiguration": rt.protocol_configuration, + }) + + +def _fake_rt( + *, + name: str = "n", + prompt: str = "old", + tools=None, + description=None, + protocol_type: str = SUPER_AGENT_PROTOCOL_TYPE, +) -> SimpleNamespace: + """Build a minimal AgentRuntime-like object for ``from_agent_runtime``.""" + cfg_dict = { + "path": "/invoke", + "headers": {}, + "body": { + "forwardedProps": { + "prompt": prompt, + "agents": [], + "tools": tools if tools is not None else [], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": None, + "modelName": None, + "metadata": {"agentRuntimeName": name}, + } + }, + } + pc = { + "type": protocol_type, + "protocolSettings": [{ + "type": protocol_type, + "name": name, + "path": "/invoke", + "config": json.dumps(cfg_dict), + }], + "externalEndpoint": "https://x.com/super-agents/__SUPER_AGENT__", + } + return SimpleNamespace( + agent_runtime_name=name, + agent_runtime_id="rid", + agent_runtime_arn="arn", + status="READY", + created_at="t1", + last_updated_at="t2", + description=description, + protocol_configuration=pc, + ) + + +# ─── create ────────────────────────────────────────────────── + + +async def test_create_async_calls_runtime_with_correct_input(): + captured_input = {} + + async def _create_async(dara_input, config=None): + captured_input["dara"] = dara_input + return _rt_to_dara_result(_fake_rt(name="alpha", prompt="new")) + + ctrl = MagicMock() + ctrl.create_agent_runtime_async = _create_async + with patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.create_async(name="alpha", prompt="new") + + dara = captured_input["dara"] + # Dara-level model uses snake_case attributes + assert dara.agent_runtime_name == "alpha" + assert dara.system_tags == [SUPER_AGENT_TAG] + pc = dara.protocol_configuration + # externalEndpoint preserved via the additive Dara monkey-patch + assert pc.external_endpoint.endswith("/super-agents/__SUPER_AGENT__") + first = pc.protocol_settings[0] + assert first.type == SUPER_AGENT_PROTOCOL_TYPE + cfg_json = json.loads(first.config) + assert cfg_json["body"]["forwardedProps"]["prompt"] == "new" + assert agent.name == "alpha" + + +async def test_create_async_returns_super_agent_with_client_handle(): + ctrl = MagicMock() + ctrl.create_agent_runtime_async = AsyncMock( + return_value=_rt_to_dara_result(_fake_rt(name="alpha")) + ) + with patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.create_async(name="alpha") + assert agent._client is client + + +# ─── get ──────────────────────────────────────────────────── + + +async def test_get_async_normal(): + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[_fake_rt(name="alpha")]) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.get_async("alpha") + assert agent.name == "alpha" + assert agent._client is client + + +async def test_get_async_not_super_agent_raises(): + rt_client = MagicMock() + rt_client.list_async = AsyncMock( + return_value=[_fake_rt(name="alpha", protocol_type="HTTP")] + ) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + with pytest.raises(ValueError) as exc: + await client.get_async("alpha") + assert "is not a super agent" in str(exc.value) + + +# ─── update (read-merge-write) ─────────────────────────────── + + +async def test_update_async_partial_modify_prompt_only(): + existing = _fake_rt(name="x", prompt="old", tools=["t1"]) + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[existing]) + rt_client.get_async = AsyncMock(return_value=existing) + ctrl = MagicMock() + captured = {} + + async def _update(agent_id, dara_input, config=None): + captured["dara"] = dara_input + return _rt_to_dara_result( + _fake_rt(name="x", prompt="new", tools=["t1"]) + ) + + ctrl.update_agent_runtime_async = _update + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + await client.update_async("x", prompt="new") + cfg_json = json.loads( + captured["dara"].protocol_configuration.protocol_settings[0].config + ) + forwarded = cfg_json["body"]["forwardedProps"] + assert forwarded["prompt"] == "new" + assert forwarded["tools"] == ["t1"] + + +async def test_update_async_explicit_none_clears_field(): + existing = _fake_rt(name="x", prompt="old") + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[existing]) + rt_client.get_async = AsyncMock(return_value=existing) + ctrl = MagicMock() + captured = {} + + async def _update(agent_id, dara_input, config=None): + captured["dara"] = dara_input + return _rt_to_dara_result(_fake_rt(name="x")) + + ctrl.update_agent_runtime_async = _update + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + await client.update_async("x", prompt=None) + cfg_json = json.loads( + captured["dara"].protocol_configuration.protocol_settings[0].config + ) + assert cfg_json["body"]["forwardedProps"]["prompt"] is None + + +async def test_update_async_multiple_fields(): + existing = _fake_rt(name="x", prompt="old") + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[existing]) + rt_client.get_async = AsyncMock(return_value=existing) + ctrl = MagicMock() + captured = {} + + async def _update(agent_id, dara_input, config=None): + captured["dara"] = dara_input + return _rt_to_dara_result(_fake_rt(name="x")) + + ctrl.update_agent_runtime_async = _update + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + await client.update_async( + "x", prompt="p", tools=["a", "b"], description="d" + ) + cfg_json = json.loads( + captured["dara"].protocol_configuration.protocol_settings[0].config + ) + forwarded = cfg_json["body"]["forwardedProps"] + assert forwarded["prompt"] == "p" + assert forwarded["tools"] == ["a", "b"] + assert captured["dara"].description == "d" + + +async def test_update_async_target_not_super_agent_raises(): + rt_client = MagicMock() + rt_client.list_async = AsyncMock( + return_value=[_fake_rt(name="x", protocol_type="HTTP")] + ) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + with pytest.raises(ValueError): + await client.update_async("x", prompt="p") + + +# ─── delete ───────────────────────────────────────────────── + + +async def test_delete_async_calls_runtime(): + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=[_fake_rt(name="alpha")]) + rt_client.delete_async = AsyncMock(return_value=None) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = await client.delete_async("alpha") + assert result is None + rt_client.delete_async.assert_awaited_once() + called_with = rt_client.delete_async.await_args.args + # list_async returns rt with agent_runtime_id="rid"; delete_async 用 id 调用 + assert called_with[0] == "rid" + + +# ─── list ─────────────────────────────────────────────────── + + +async def test_list_async_default_pagination(): + rt_client = MagicMock() + captured = {} + + async def _list(inp=None, config=None): + captured["inp"] = inp + return [] + + rt_client.list_async = _list + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + await client.list_async() + assert captured["inp"].page_number == 1 + assert captured["inp"].page_size == 20 + assert captured["inp"].system_tags == SUPER_AGENT_TAG + + +async def test_list_async_custom_pagination(): + rt_client = MagicMock() + captured = {} + + async def _list(inp=None, config=None): + captured["inp"] = inp + return [] + + rt_client.list_async = _list + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + await client.list_async(page_number=2, page_size=50) + assert captured["inp"].page_number == 2 + assert captured["inp"].page_size == 50 + + +async def test_list_async_rejects_tags_kwarg(): + client = SuperAgentClient() + with pytest.raises(TypeError): + await client.list_async(tags=["x"]) # type: ignore[call-arg] + + +async def test_list_async_filters_non_super_agent(): + items = [ + _fake_rt(name="a"), + _fake_rt(name="b", protocol_type="HTTP"), + _fake_rt(name="c"), + ] + rt_client = MagicMock() + rt_client.list_async = AsyncMock(return_value=items) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = await client.list_async() + assert [a.name for a in result] == ["a", "c"] + + +async def test_list_all_async_auto_pagination(): + # page_size=50 is the default for list_all_async; craft pages accordingly + page1 = [_fake_rt(name=f"a{i}") for i in range(50)] + page2 = [_fake_rt(name=f"a{i}") for i in range(50, 85)] + pages = [page1, page2, []] + rt_client = MagicMock() + rt_client.list_async = AsyncMock(side_effect=pages) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = await client.list_all_async() + names = [a.name for a in result] + assert len(names) == 85 + assert len(set(names)) == 85 + + +# ─── sync mirrors async ───────────────────────────────────── + + +def test_sync_methods_exist_and_mirror_async(): + rt_client = MagicMock() + rt_client.get = MagicMock(return_value=_fake_rt(name="alpha")) + rt_client.delete = MagicMock(return_value=None) + rt_client.list = MagicMock(return_value=[_fake_rt(name="alpha")]) + ctrl = MagicMock() + ctrl.create_agent_runtime = MagicMock( + return_value=_rt_to_dara_result(_fake_rt(name="alpha")) + ) + ctrl.update_agent_runtime = MagicMock( + return_value=_rt_to_dara_result(_fake_rt(name="alpha")) + ) + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + agent = client.create(name="alpha") + assert agent.name == "alpha" + assert client.get("alpha").name == "alpha" + assert client.update("alpha", prompt="p").name == "alpha" + assert client.delete("alpha") is None + assert len(client.list()) == 1 + # list_all hits list() once, then empty page + rt_client.list = MagicMock(side_effect=[[_fake_rt(name="a")], []]) + result = client.list_all() + assert len(result) == 1 + + +# ─── not-found / not-super error paths (async + sync) ────── + + +def _make_client_with_list(list_items, *, sync=False) -> SuperAgentClient: + """Helper: build a SuperAgentClient where _rt.list(_async) returns ``list_items``.""" + rt_client = MagicMock() + if sync: + rt_client.list = MagicMock(return_value=list_items) + else: + rt_client.list_async = AsyncMock(return_value=list_items) + patcher = patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ) + patcher.start() + client = SuperAgentClient(config=_client_config()) + client._patcher = patcher # type: ignore[attr-defined] + client._rt_client = rt_client # type: ignore[attr-defined] + return client + + +async def test_get_async_not_found_raises(): + client = _make_client_with_list([]) + try: + with pytest.raises(ValueError, match="not found"): + await client.get_async("missing") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_get_sync_not_found_raises(): + client = _make_client_with_list([], sync=True) + try: + with pytest.raises(ValueError, match="not found"): + client.get("missing") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_get_sync_not_super_agent_raises(): + client = _make_client_with_list( + [_fake_rt(name="x", protocol_type="HTTP")], sync=True + ) + try: + with pytest.raises(ValueError, match="is not a super agent"): + client.get("x") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +async def test_update_async_not_found_raises(): + client = _make_client_with_list([]) + try: + with pytest.raises(ValueError, match="not found"): + await client.update_async("missing", prompt="p") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_update_sync_not_found_raises(): + client = _make_client_with_list([], sync=True) + try: + with pytest.raises(ValueError, match="not found"): + client.update("missing", prompt="p") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_update_sync_not_super_agent_raises(): + client = _make_client_with_list( + [_fake_rt(name="x", protocol_type="HTTP")], sync=True + ) + try: + with pytest.raises(ValueError, match="is not a super agent"): + client.update("x", prompt="p") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +async def test_delete_async_not_found_raises(): + client = _make_client_with_list([]) + try: + with pytest.raises(ValueError, match="not found"): + await client.delete_async("missing") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +def test_delete_sync_not_found_raises(): + client = _make_client_with_list([], sync=True) + try: + with pytest.raises(ValueError, match="not found"): + client.delete("missing") + finally: + client._patcher.stop() # type: ignore[attr-defined] + + +# ─── _find_rt_by_name_* 多页分页 ─────────────────────────────── + + +async def test_find_rt_by_name_async_paginates_until_match(): + """page1 全是非匹配 item (满 50), page2 才有 target → 需要翻页.""" + page1 = [_fake_rt(name=f"other{i}") for i in range(50)] + page2 = [_fake_rt(name="target")] + rt_client = MagicMock() + rt_client.list_async = AsyncMock(side_effect=[page1, page2]) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.get_async("target") + assert agent.name == "target" + assert rt_client.list_async.await_count == 2 + + +def test_find_rt_by_name_sync_paginates_until_match(): + page1 = [_fake_rt(name=f"other{i}") for i in range(50)] + page2 = [_fake_rt(name="target")] + rt_client = MagicMock() + rt_client.list = MagicMock(side_effect=[page1, page2]) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + agent = client.get("target") + assert agent.name == "target" + assert rt_client.list.call_count == 2 + + +# ─── _wait_final / _raise_if_failed ────────────────────────── + + +def test_raise_if_failed_raises_on_failed_status(): + from agentrun.super_agent.client import _raise_if_failed + + rt = SimpleNamespace( + status="CREATE_FAILED", + status_reason="disk full", + agent_runtime_name="x", + ) + with pytest.raises(RuntimeError) as exc: + _raise_if_failed(rt, action="create") + assert "disk full" in str(exc.value) + assert "CREATE_FAILED" in str(exc.value) + + +def test_raise_if_failed_noop_on_ready(): + from agentrun.super_agent.client import _raise_if_failed + + rt = SimpleNamespace(status="READY") + _raise_if_failed(rt, action="update") # no raise + + +async def test_wait_final_async_timeout(): + rt_pending = _fake_rt(name="x") + rt_pending.status = "CREATING" + rt_client = MagicMock() + rt_client.get_async = AsyncMock(return_value=rt_pending) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + with pytest.raises(TimeoutError): + await client._wait_final_async( + "rid", interval_seconds=0, timeout_seconds=-1 + ) + + +def test_wait_final_sync_timeout(): + rt_pending = _fake_rt(name="x") + rt_pending.status = "CREATING" + rt_client = MagicMock() + rt_client.get = MagicMock(return_value=rt_pending) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + with pytest.raises(TimeoutError): + client._wait_final("rid", interval_seconds=0, timeout_seconds=-1) + + +async def test_wait_final_async_retries_then_ready(): + """第一次 get 返回 CREATING → await asyncio.sleep → 第二次 READY.""" + pending = _fake_rt(name="x") + pending.status = "CREATING" + ready = _fake_rt(name="x") # status=READY by default + rt_client = MagicMock() + rt_client.get_async = AsyncMock(side_effect=[pending, ready]) + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch("agentrun.super_agent.client.asyncio.sleep", AsyncMock()), + ): + client = SuperAgentClient(config=_client_config()) + result = await client._wait_final_async( + "rid", interval_seconds=0, timeout_seconds=60 + ) + assert getattr(result, "status", None) == "READY" + assert rt_client.get_async.await_count == 2 + + +def test_wait_final_sync_retries_then_ready(): + pending = _fake_rt(name="x") + pending.status = "CREATING" + ready = _fake_rt(name="x") + rt_client = MagicMock() + rt_client.get = MagicMock(side_effect=[pending, ready]) + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch("agentrun.super_agent.client.time.sleep", MagicMock()), + ): + client = SuperAgentClient(config=_client_config()) + result = client._wait_final( + "rid", interval_seconds=0, timeout_seconds=60 + ) + assert getattr(result, "status", None) == "READY" + assert rt_client.get.call_count == 2 + + +# ─── create: 非 final 状态触发 _wait_final ──────────────────── + + +async def test_create_async_non_final_status_triggers_wait(): + creating = _fake_rt(name="alpha") + creating.status = "CREATING" + ready = _fake_rt(name="alpha") # READY + rt_client = MagicMock() + rt_client.get_async = AsyncMock(return_value=ready) + ctrl = MagicMock() + ctrl.create_agent_runtime_async = AsyncMock( + return_value=_rt_to_dara_result(creating) + ) + with ( + patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ), + patch( + "agentrun.super_agent.client.AgentRuntimeControlAPI", + return_value=ctrl, + ), + ): + client = SuperAgentClient(config=_client_config()) + agent = await client.create_async(name="alpha") + assert agent.name == "alpha" + rt_client.get_async.assert_awaited() # _wait_final_async 确实调了 get + + +# ─── sync list / list_all 未覆盖分支 ───────────────────────── + + +def test_list_sync_filters_non_super_agent(): + items = [ + _fake_rt(name="a"), + _fake_rt(name="b", protocol_type="HTTP"), + _fake_rt(name="c"), + ] + rt_client = MagicMock() + rt_client.list = MagicMock(return_value=items) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = client.list() + assert [a.name for a in result] == ["a", "c"] + + +def test_list_all_sync_multi_page(): + page1 = [_fake_rt(name=f"a{i}") for i in range(50)] + page2 = [_fake_rt(name=f"a{i}") for i in range(50, 85)] + pages = [page1, page2] + rt_client = MagicMock() + rt_client.list = MagicMock(side_effect=pages) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = client.list_all() + assert len(result) == 85 + + +def test_list_all_async_empty_first_page_breaks(): + """list_async 首页直接空, list_all 立刻 break.""" + + async def _list(*args, **kwargs): + return [] + + rt_client = MagicMock() + rt_client.list_async = _list + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = asyncio.run(client.list_all_async()) + assert result == [] + + +def test_list_all_sync_empty_first_page_breaks(): + rt_client = MagicMock() + rt_client.list = MagicMock(return_value=[]) + with patch( + "agentrun.super_agent.client.AgentRuntimeClient", + return_value=rt_client, + ): + client = SuperAgentClient(config=_client_config()) + result = client.list_all() + assert result == [] + + +def test_no_agent_runtime_in_public_signatures(): + """No public SuperAgentClient method exposes AgentRuntime-related types.""" + public_methods = [m for m in dir(SuperAgentClient) if not m.startswith("_")] + for name in public_methods: + attr = getattr(SuperAgentClient, name) + if not callable(attr): + continue + sig = inspect.signature(attr) + all_annotations = [p.annotation for p in sig.parameters.values()] + all_annotations.append(sig.return_annotation) + rendered = " ".join(str(a) for a in all_annotations) + assert ( + "AgentRuntime" not in rendered + ), f"{name} exposes AgentRuntime in its signature: {rendered}" diff --git a/tests/unittests/super_agent/test_control.py b/tests/unittests/super_agent/test_control.py new file mode 100644 index 0000000..33d36a1 --- /dev/null +++ b/tests/unittests/super_agent/test_control.py @@ -0,0 +1,429 @@ +"""Unit tests for ``agentrun.super_agent.api.control``.""" + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from agentrun.super_agent.api.control import ( + _add_ram_prefix_to_host, + API_VERSION, + build_super_agent_endpoint, + ensure_super_agent_patches_applied, + from_agent_runtime, + is_super_agent, + parse_super_agent_config, + SUPER_AGENT_PROTOCOL_TYPE, + SUPER_AGENT_RESOURCE_PATH, + SUPER_AGENT_TAG, + to_create_input, + to_update_input, +) +from agentrun.super_agent.api.data import SuperAgentDataAPI +from agentrun.utils.config import Config + +# 本文件部分测试依赖 Dara ProtocolConfiguration 已被打过补丁 (externalEndpoint), +# 显式在模块加载时触发补丁 (幂等, 与 SuperAgentClient.__init__ 内触发点一致)。 +ensure_super_agent_patches_applied() + +# ─── build_super_agent_endpoint ──────────────────────────────── + + +def test_build_super_agent_endpoint_production(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + ep = build_super_agent_endpoint(cfg) + assert ( + ep + == "https://123-ram.agentrun-data.cn-hangzhou.aliyuncs.com/super-agents/__SUPER_AGENT__" + ) + + +def test_build_super_agent_endpoint_pre_environment(): + cfg = Config( + data_endpoint=( + "http://1431999136518149.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + ) + ep = build_super_agent_endpoint(cfg) + assert ( + ep + == "http://1431999136518149-ram.funagent-data-pre.cn-hangzhou.aliyuncs.com/super-agents/__SUPER_AGENT__" + ) + + +def test_build_super_agent_endpoint_custom_gateway(): + cfg = Config(data_endpoint="https://my-gateway.example.com") + ep = build_super_agent_endpoint(cfg) + assert ep == "https://my-gateway.example.com/super-agents/__SUPER_AGENT__" + + +def test_build_super_agent_endpoint_unknown_first_segment(): + # `agentrun-data` in the first segment → no `-ram` rewrite + cfg = Config(data_endpoint="https://agentrun-data.example.com") + ep = build_super_agent_endpoint(cfg) + assert ( + ep == "https://agentrun-data.example.com/super-agents/__SUPER_AGENT__" + ) + + +# ─── _add_ram_prefix_to_host ────────────────────────────────── + + +def test_add_ram_prefix_to_host_no_netloc(): + assert _add_ram_prefix_to_host("") == "" + assert _add_ram_prefix_to_host("/path/only") == "/path/only" + + +def test_add_ram_prefix_to_host_single_segment_host(): + # Host has a single segment → no rewrite + assert ( + _add_ram_prefix_to_host("https://localhost:8080") + == "https://localhost:8080" + ) + + +def test_add_ram_prefix_to_host_unknown_domain(): + assert ( + _add_ram_prefix_to_host("https://foo.example.com") + == "https://foo.example.com" + ) + + +# ─── SuperAgentDataAPI URL 含版本号 ────────────────────────── +def test_build_data_url_via_with_path_includes_version(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + api = SuperAgentDataAPI("demo", config=cfg) + url = api.with_path("invoke") + assert url.endswith( + f"/{API_VERSION}/super-agents/{SUPER_AGENT_RESOURCE_PATH}/invoke" + ) + + +# ─── to_create_input ────────────────────────────────────────── + + +def test_to_create_input_minimal(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_create_input("alpha", cfg=cfg) + assert inp.agent_runtime_name == "alpha" + assert inp.system_tags == [SUPER_AGENT_TAG] + pc = inp.protocol_configuration + assert pc.type == SUPER_AGENT_PROTOCOL_TYPE + assert pc.external_endpoint.endswith("/super-agents/__SUPER_AGENT__") + settings = pc.protocol_settings + assert len(settings) == 1 + cfg_dict = json.loads(settings[0]["config"]) + assert cfg_dict["path"] == "/invoke" + assert cfg_dict["headers"] == {} + forwarded = cfg_dict["body"]["forwardedProps"] + assert forwarded["agents"] == [] + assert forwarded["metadata"] == {"agentRuntimeName": "alpha"} + + +def test_to_create_input_full(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_create_input( + "bravo", + prompt="hello", + agents=["a1"], + tools=["t1", "t2"], + skills=["s1"], + sandboxes=["sb1"], + workspaces=["ws1"], + model_service_name="foo", + model_name="bar", + cfg=cfg, + ) + pc_dict = inp.model_dump()["protocolConfiguration"] + settings_cfg = json.loads(pc_dict["protocolSettings"][0]["config"]) + assert settings_cfg["path"] == "/invoke" + assert settings_cfg["headers"] == {} + forwarded = settings_cfg["body"]["forwardedProps"] + assert forwarded["prompt"] == "hello" + assert forwarded["agents"] == ["a1"] + assert forwarded["tools"] == ["t1", "t2"] + assert forwarded["skills"] == ["s1"] + assert forwarded["sandboxes"] == ["sb1"] + assert forwarded["workspaces"] == ["ws1"] + assert forwarded["modelServiceName"] == "foo" + assert forwarded["modelName"] == "bar" + + +def test_to_create_input_system_tags_fixed(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_create_input("c", cfg=cfg) + assert inp.system_tags == [SUPER_AGENT_TAG] + + +def test_to_create_input_metadata_only_agent_runtime_name(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_create_input("d", cfg=cfg) + settings_cfg = json.loads( + inp.protocol_configuration.protocol_settings[0]["config"] + ) + assert settings_cfg["body"]["forwardedProps"]["metadata"] == { + "agentRuntimeName": "d" + } + + +def test_to_create_input_uses_pre_environment_endpoint(): + cfg = Config( + data_endpoint=( + "http://1431999136518149.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + ) + inp = to_create_input("pre-agent", cfg=cfg) + ep = inp.protocol_configuration.external_endpoint + assert "funagent-data-pre" in ep + assert "-ram" in ep + + +# ─── is_super_agent / parse_super_agent_config ───────────────── + + +def _make_rt(**kwargs): + """Minimal fake AgentRuntime-like object.""" + defaults = { + "agent_runtime_name": "n", + "agent_runtime_id": "rid", + "agent_runtime_arn": "arn", + "status": "READY", + "created_at": "2026-01-01", + "last_updated_at": "2026-01-02", + "description": None, + "protocol_configuration": None, + } + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + +def test_from_agent_runtime(): + config_json = json.dumps({ + "path": "/invoke", + "headers": {}, + "body": { + "forwardedProps": { + "prompt": "hi", + "agents": ["a"], + "tools": ["t"], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": "svc", + "modelName": "mod", + "metadata": {"agentRuntimeName": "foo"}, + } + }, + }) + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "config": config_json, + "name": "foo", + "path": "/invoke", + }], + "externalEndpoint": "https://x.com/super-agents/__SUPER_AGENT__", + } + rt = _make_rt(agent_runtime_name="foo", protocol_configuration=pc) + agent = from_agent_runtime(rt) + assert agent.name == "foo" + assert agent.prompt == "hi" + assert agent.agents == ["a"] + assert agent.tools == ["t"] + assert agent.model_service_name == "svc" + assert agent.model_name == "mod" + assert ( + agent.external_endpoint == "https://x.com/super-agents/__SUPER_AGENT__" + ) + + +def test_from_agent_runtime_legacy_flat_config(): + """旧结构兼容: config 是扁平 dict, 业务字段直接在根 (历史 AgentRuntime).""" + config_json = json.dumps({ + "prompt": "legacy", + "agents": ["la"], + "tools": ["lt"], + "skills": [], + "sandboxes": [], + "workspaces": [], + "modelServiceName": "legacy-svc", + "modelName": "legacy-mod", + "metadata": {"agentRuntimeName": "legacy"}, + }) + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "config": config_json, + "name": "legacy", + "path": "/invoke", + }], + "externalEndpoint": "https://x.com/super-agents/__SUPER_AGENT__", + } + rt = _make_rt(agent_runtime_name="legacy", protocol_configuration=pc) + agent = from_agent_runtime(rt) + assert agent.prompt == "legacy" + assert agent.agents == ["la"] + assert agent.model_service_name == "legacy-svc" + + +def test_parse_super_agent_config_dict_config_new_structure(): + """config 已经是 dict (非字符串) 时也能拍平.""" + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "config": { + "path": "/invoke", + "headers": {}, + "body": { + "forwardedProps": { + "prompt": "p", + "agents": [], + } + }, + }, + }], + } + business = parse_super_agent_config(_make_rt(protocol_configuration=pc)) + assert business["prompt"] == "p" + assert business["agents"] == [] + + +def test_parse_super_agent_config_dict_config_legacy_flat(): + """config 是 dict + 旧扁平结构时走 fallback, 原样返回.""" + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{ + "type": SUPER_AGENT_PROTOCOL_TYPE, + "config": {"prompt": "legacy-dict", "agents": ["la"]}, + }], + } + business = parse_super_agent_config(_make_rt(protocol_configuration=pc)) + assert business == {"prompt": "legacy-dict", "agents": ["la"]} + + +def test_flatten_protocol_config_non_dict_returns_empty(): + """非 dict 输入 (防御分支) 返回空 dict.""" + from agentrun.super_agent.api.control import _flatten_protocol_config + + assert _flatten_protocol_config(None) == {} + assert _flatten_protocol_config("not-a-dict") == {} + assert _flatten_protocol_config([1, 2, 3]) == {} + + +def test_flatten_protocol_config_body_without_forwarded_props(): + """body 存在但缺 forwardedProps → fallback 到整个 cfg (旧结构).""" + from agentrun.super_agent.api.control import _flatten_protocol_config + + cfg = {"body": {"other": "x"}, "prompt": "flat"} + assert _flatten_protocol_config(cfg) == cfg + + +def test_is_super_agent_true(): + pc = { + "type": SUPER_AGENT_PROTOCOL_TYPE, + "protocolSettings": [{"type": SUPER_AGENT_PROTOCOL_TYPE}], + } + assert is_super_agent(_make_rt(protocol_configuration=pc)) + + +def test_is_super_agent_false(): + for type_name in ("HTTP", "MCP", "OTHER"): + pc = {"type": type_name, "protocolSettings": [{"type": type_name}]} + assert not is_super_agent(_make_rt(protocol_configuration=pc)) + # No protocol_configuration + assert not is_super_agent(_make_rt(protocol_configuration=None)) + + +def test_parse_super_agent_config_invalid_json_returns_empty(): + pc = { + "protocolSettings": [ + {"type": SUPER_AGENT_PROTOCOL_TYPE, "config": "not-json"} + ] + } + assert parse_super_agent_config(_make_rt(protocol_configuration=pc)) == {} + + +def test_parse_super_agent_config_missing_config_returns_empty(): + pc = {"protocolSettings": [{"type": SUPER_AGENT_PROTOCOL_TYPE}]} + assert parse_super_agent_config(_make_rt(protocol_configuration=pc)) == {} + + +# ─── to_update_input ────────────────────────────────────────── + + +def test_to_update_input_full_protocol_replace(): + cfg = Config(account_id="123", region_id="cn-hangzhou") + inp = to_update_input( + "alpha", + { + "description": "new", + "prompt": "p", + "agents": [], + "tools": ["t"], + "skills": [], + "sandboxes": [], + "workspaces": [], + "model_service_name": None, + "model_name": None, + }, + cfg=cfg, + ) + assert inp.description == "new" + settings = inp.protocol_configuration.protocol_settings + assert len(settings) == 1 + cfg_json = json.loads(settings[0]["config"]) + forwarded = cfg_json["body"]["forwardedProps"] + assert forwarded["metadata"]["agentRuntimeName"] == "alpha" + assert forwarded["prompt"] == "p" + assert forwarded["tools"] == ["t"] + + +# ─── Dara ListAgentRuntimesRequest systemTags 原生字段 ────────────── +# ``systemTags`` 已由 Dara SDK 原生支持, 无需补丁。以下测试只校验 pydantic → +# Dara roundtrip 能把 ``system_tags`` 保留到请求 query。 + + +def test_list_request_from_map_preserves_system_tags(): + from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest + + req = ListAgentRuntimesRequest().from_map({ + "systemTags": SUPER_AGENT_TAG, + "pageNumber": 1, + "pageSize": 20, + }) + assert req.system_tags == SUPER_AGENT_TAG + + +def test_list_request_to_map_preserves_system_tags(): + from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest + + req = ListAgentRuntimesRequest() + req.system_tags = SUPER_AGENT_TAG + assert req.to_map().get("systemTags") == SUPER_AGENT_TAG + + +def test_list_with_options_writes_system_tags_query(): + from alibabacloud_agentrun20250910.client import Client as _DaraClient + from alibabacloud_agentrun20250910.models import ListAgentRuntimesRequest + from darabonba.runtime import RuntimeOptions + + captured = {} + + def _fake_call_api(self, params, req, rt): + captured["query"] = dict(req.query) if req.query else {} + raise RuntimeError("_stop_") + + client = _DaraClient.__new__(_DaraClient) + client._endpoint = "x" + client.call_api = _fake_call_api.__get__(client, _DaraClient) + + req = ListAgentRuntimesRequest( + page_number=1, page_size=20, system_tags=SUPER_AGENT_TAG + ) + with pytest.raises(RuntimeError, match="_stop_"): + client.list_agent_runtimes_with_options(req, {}, RuntimeOptions()) + assert captured["query"].get("systemTags") == SUPER_AGENT_TAG diff --git a/tests/unittests/super_agent/test_data_api.py b/tests/unittests/super_agent/test_data_api.py new file mode 100644 index 0000000..ec2d7e2 --- /dev/null +++ b/tests/unittests/super_agent/test_data_api.py @@ -0,0 +1,667 @@ +"""Unit tests for ``agentrun.super_agent.api.data.SuperAgentDataAPI``.""" + +import json +import re + +import httpx +import pytest +import respx + +from agentrun.super_agent.api.data import SuperAgentDataAPI +from agentrun.utils.config import Config + + +def _auth_cfg(**overrides) -> Config: + """Config with RAM AK/SK so ``DataAPI.auth`` actually signs.""" + base = dict( + access_key_id="AK", + access_key_secret="SK", + account_id="123", + region_id="cn-hangzhou", + ) + base.update(overrides) + return Config(**base) + + +# ─── URL construction (production / pre / custom gateway) ──── + + +@respx.mock +async def test_invoke_async_phase1_url_includes_version_production(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("agent-prod", config=cfg) + route = respx.post( + re.compile( + r"https://123-ram\.agentrun-data\.cn-hangzhou\.aliyuncs\.com" + r"/2025-09-10/super-agents/__SUPER_AGENT__/invoke" + ) + ).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversationId": "c1", + "url": "https://x/stream", + "headers": {}, + } + }, + ) + ) + await api.invoke_async([{"role": "user", "content": "hi"}]) + assert route.called + + +@respx.mock +async def test_invoke_async_phase1_url_pre_environment(): + cfg = _auth_cfg( + data_endpoint=( + "http://1431999136518149.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + ) + api = SuperAgentDataAPI("agent-pre", config=cfg) + route = respx.post( + re.compile( + r"http://1431999136518149-ram\.funagent-data-pre\.cn-hangzhou\.aliyuncs\.com" + r"/2025-09-10/super-agents/__SUPER_AGENT__/invoke" + ) + ).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://stream", + "headers": {}, + } + }, + ) + ) + await api.invoke_async([{"role": "user", "content": "hi"}]) + assert route.called + + +@respx.mock +async def test_invoke_async_phase1_url_custom_gateway_no_ram(): + cfg = _auth_cfg(data_endpoint="https://my-gateway.example.com") + api = SuperAgentDataAPI("agent-cust", config=cfg) + route = respx.post( + "https://my-gateway.example.com/2025-09-10/super-agents/__SUPER_AGENT__/invoke" + ).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + ) + await api.invoke_async([{"role": "user", "content": "hi"}]) + assert route.called + + +@respx.mock +async def test_list_conversations_async_url_pre_environment(): + cfg = _auth_cfg( + data_endpoint="http://111.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + api = SuperAgentDataAPI("n", config=cfg) + route = respx.get( + re.compile( + r"http://111-ram\.funagent-data-pre\.cn-hangzhou\.aliyuncs\.com" + r"/2025-09-10/super-agents/__SUPER_AGENT__/conversations(\?.*)?$" + ) + ).mock( + return_value=httpx.Response(200, json={"data": {"conversations": []}}) + ) + await api.list_conversations_async() + assert route.called + + +@respx.mock +async def test_get_conversation_async_url_pre_environment(): + cfg = _auth_cfg( + data_endpoint="http://111.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + api = SuperAgentDataAPI("n", config=cfg) + route = respx.get( + "http://111-ram.funagent-data-pre.cn-hangzhou.aliyuncs.com" + "/2025-09-10/super-agents/__SUPER_AGENT__/conversations/cid" + ).mock(return_value=httpx.Response(200, json={"data": {}})) + await api.get_conversation_async("cid") + assert route.called + + +@respx.mock +async def test_delete_conversation_async_url_pre_environment(): + cfg = _auth_cfg( + data_endpoint="http://111.funagent-data-pre.cn-hangzhou.aliyuncs.com" + ) + api = SuperAgentDataAPI("n", config=cfg) + route = respx.delete( + "http://111-ram.funagent-data-pre.cn-hangzhou.aliyuncs.com" + "/2025-09-10/super-agents/__SUPER_AGENT__/conversations/cid" + ).mock(return_value=httpx.Response(200)) + await api.delete_conversation_async("cid") + assert route.called + + +# ─── body shape ─────────────────────────────────────────────── + + +@respx.mock +async def test_invoke_async_body_new_conversation(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["body"] = json.loads(request.content) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async([{"role": "user", "content": "hi"}]) + assert captured["body"]["messages"] == [{"role": "user", "content": "hi"}] + assert captured["body"]["forwardedProps"]["metadata"] == { + "agentRuntimeName": "demo" + } + assert "conversationId" not in captured["body"]["forwardedProps"] + + +@respx.mock +async def test_invoke_async_body_continue_conversation(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["body"] = json.loads(request.content) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "abc", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async( + [{"role": "user", "content": "hi"}], conversation_id="abc" + ) + assert captured["body"]["forwardedProps"]["conversationId"] == "abc" + + +@respx.mock +async def test_invoke_async_body_forwarded_extras_passthrough(): + """``forwarded_extras`` 业务字段 MUST 出现在 forwardedProps 顶层.""" + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["body"] = json.loads(request.content) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async( + [{"role": "user", "content": "hi"}], + forwarded_extras={ + "prompt": "p", + "agents": ["a1"], + "modelServiceName": "svc", + "modelName": "mod", + }, + ) + fp = captured["body"]["forwardedProps"] + assert fp["prompt"] == "p" + assert fp["agents"] == ["a1"] + assert fp["modelServiceName"] == "svc" + assert fp["modelName"] == "mod" + # SDK 托管字段不受 extras 影响 + assert fp["metadata"] == {"agentRuntimeName": "demo"} + + +@respx.mock +async def test_invoke_async_body_extras_cannot_override_sdk_fields(): + """extras 里带 metadata/conversationId 也不能覆盖 SDK 托管字段.""" + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["body"] = json.loads(request.content) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "real", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async( + [{"role": "user", "content": "hi"}], + conversation_id="real", + forwarded_extras={ + "metadata": {"agentRuntimeName": "SPOOFED"}, + "conversationId": "SPOOFED", + "prompt": "p", + }, + ) + fp = captured["body"]["forwardedProps"] + assert fp["metadata"] == {"agentRuntimeName": "demo"} + assert fp["conversationId"] == "real" + assert fp["prompt"] == "p" + + +# ─── signing ────────────────────────────────────────────────── + + +@respx.mock +async def test_invoke_async_request_signed(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async([{"role": "user"}]) + h = captured["headers"] + assert any(k.lower() == "agentrun-authorization" for k in h) + assert any(k.lower() == "x-acs-date" for k in h) + assert any(k.lower() == "x-acs-content-sha256" for k in h) + ct = next((v for k, v in h.items() if k.lower() == "content-type"), None) + assert ct == "application/json" + + +# ─── returns InvokeResponseData ────────────────────────────── + + +@respx.mock +async def test_invoke_async_returns_invoke_response_data(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.post(re.compile(r".*/invoke$")).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://stream", + "headers": {"X-Super-Agent-Session-Id": "s"}, + } + }, + ) + ) + resp = await api.invoke_async([]) + assert resp.conversation_id == "c" + assert resp.stream_url == "https://stream" + assert resp.stream_headers == {"X-Super-Agent-Session-Id": "s"} + + +@respx.mock +async def test_invoke_async_missing_conversation_id_raises(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.post(re.compile(r".*/invoke$")).mock( + return_value=httpx.Response( + 200, json={"data": {"url": "https://s", "headers": {}}} + ) + ) + with pytest.raises(ValueError) as exc: + await api.invoke_async([]) + assert "missing" in str(exc.value) + assert "conversationId" in str(exc.value) + + +@respx.mock +async def test_invoke_async_5xx_raises_http_error(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.post(re.compile(r".*/invoke$")).mock( + return_value=httpx.Response(500, text="boom") + ) + with pytest.raises(httpx.HTTPStatusError): + await api.invoke_async([]) + + +@respx.mock +async def test_invoke_async_user_headers_merged(): + cfg = _auth_cfg(headers={"X-Custom": "v"}) + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response( + 200, + json={ + "data": { + "conversationId": "c", + "url": "https://s", + "headers": {}, + } + }, + ) + + respx.post(re.compile(r".*/invoke$")).mock(side_effect=_responder) + await api.invoke_async([]) + assert captured["headers"].get("x-custom") == "v" + assert any( + k.lower() == "agentrun-authorization" for k in captured["headers"] + ) + + +# ─── stream_async ──────────────────────────────────────────── + + +@respx.mock +async def test_stream_async_yields_sse_events(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + sse_body = b"event: m\ndata: hello\n\nevent: m\ndata: world\n\n" + respx.get("https://stream.example.com/flow").mock( + return_value=httpx.Response(200, content=sse_body) + ) + events = [] + async for ev in api.stream_async("https://stream.example.com/flow"): + events.append(ev) + assert [(e.event, e.data) for e in events] == [ + ("m", "hello"), + ("m", "world"), + ] + + +@respx.mock +async def test_stream_async_request_includes_phase1_headers_and_signed(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response(200, content=b":keep\n\n") + + respx.get("https://stream.example.com/go").mock(side_effect=_responder) + it = api.stream_async( + "https://stream.example.com/go", + stream_headers={"X-Super-Agent-Session-Id": "sess1"}, + ) + async for _ in it: + pass + h = captured["headers"] + assert h.get("x-super-agent-session-id") == "sess1" + assert any(k.lower() == "agentrun-authorization" for k in h) + + +# ─── get/delete conversation ───────────────────────────────── + + +@respx.mock +async def test_get_conversation_async_url_and_signed(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"data": {"conversationId": "c1"}}) + + respx.get( + re.compile( + r".*/2025-09-10/super-agents/__SUPER_AGENT__/conversations/c1$" + ) + ).mock(side_effect=_responder) + result = await api.get_conversation_async("c1") + assert result == {"conversationId": "c1"} + assert any( + k.lower() == "agentrun-authorization" for k in captured["headers"] + ) + + +@respx.mock +async def test_get_conversation_async_returns_empty_on_missing_data(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations/.*")).mock( + return_value=httpx.Response(200, json={"code": "ok"}) + ) + assert await api.get_conversation_async("c1") == {} + + +@respx.mock +async def test_delete_conversation_async_returns_none(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + route = respx.delete(re.compile(r".*/conversations/c1$")).mock( + return_value=httpx.Response(200) + ) + result = await api.delete_conversation_async("c1") + assert result is None + assert route.called + + +@respx.mock +async def test_delete_conversation_async_404_raises(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.delete(re.compile(r".*/conversations/missing$")).mock( + return_value=httpx.Response(404) + ) + with pytest.raises(httpx.HTTPStatusError): + await api.delete_conversation_async("missing") + + +# ─── list_conversations ────────────────────────────────────── + + +@respx.mock +async def test_list_conversations_async_parses_data_array(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations(\?.*)?$")).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversations": [ + {"conversationId": "c1", "title": "t1"}, + {"conversationId": "c2", "title": "t2"}, + ] + }, + "success": True, + }, + ) + ) + result = await api.list_conversations_async() + assert [c["conversationId"] for c in result] == ["c1", "c2"] + + +@respx.mock +async def test_list_conversations_async_metadata_query_encoded(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["url"] = str(request.url) + return httpx.Response(200, json={"data": {"conversations": []}}) + + respx.get(re.compile(r".*/conversations.*")).mock(side_effect=_responder) + await api.list_conversations_async(metadata={"agentRuntimeName": "demo"}) + assert "metadata=" in captured["url"] + # metadata is json-encoded then URL-encoded + assert "agentRuntimeName" in captured["url"] + + +@respx.mock +async def test_list_conversations_async_without_metadata_no_query(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["url"] = str(request.url) + return httpx.Response(200, json={"data": {"conversations": []}}) + + respx.get(re.compile(r".*/conversations.*")).mock(side_effect=_responder) + await api.list_conversations_async() + assert "metadata" not in captured["url"] + + +@respx.mock +async def test_list_conversations_async_request_signed(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + captured = {} + + def _responder(request): + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"data": {"conversations": []}}) + + respx.get(re.compile(r".*/conversations.*")).mock(side_effect=_responder) + await api.list_conversations_async() + assert any( + k.lower() == "agentrun-authorization" for k in captured["headers"] + ) + + +@respx.mock +async def test_list_conversations_async_missing_data_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(200, json={"success": True}) + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_data_not_dict_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(200, json={"data": "unexpected"}) + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_conversations_not_list_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response( + 200, json={"data": {"conversations": "bad"}} + ) + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_filters_non_dict_items(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response( + 200, + json={ + "data": { + "conversations": [ + {"conversationId": "c1"}, + "invalid", + None, + {"conversationId": "c2"}, + ] + } + }, + ) + ) + result = await api.list_conversations_async() + assert [c["conversationId"] for c in result] == ["c1", "c2"] + + +@respx.mock +async def test_list_conversations_async_payload_not_dict_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(200, json=[1, 2, 3]) + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_empty_body_returns_empty(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(200, content=b"") + ) + assert await api.list_conversations_async() == [] + + +@respx.mock +async def test_list_conversations_async_5xx_raises(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + respx.get(re.compile(r".*/conversations.*")).mock( + return_value=httpx.Response(500, text="boom") + ) + with pytest.raises(httpx.HTTPStatusError): + await api.list_conversations_async() + + +# ─── sync stubs are NotImplementedError ────────────────────── + + +def test_sync_methods_not_implemented(): + cfg = _auth_cfg() + api = SuperAgentDataAPI("demo", config=cfg) + for fn in ( + lambda: api.invoke([]), + lambda: api.stream("url"), + lambda: api.list_conversations(), + lambda: api.get_conversation("c"), + lambda: api.delete_conversation("c"), + ): + with pytest.raises(NotImplementedError): + fn() diff --git a/tests/unittests/super_agent/test_no_coupling.py b/tests/unittests/super_agent/test_no_coupling.py new file mode 100644 index 0000000..02fa4c8 --- /dev/null +++ b/tests/unittests/super_agent/test_no_coupling.py @@ -0,0 +1,59 @@ +"""Verify the super_agent module does not depend on conversation_service.""" + +import ast +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +import agentrun.super_agent as super_agent_pkg + + +def _python_files_under(path: Path): + return [p for p in path.rglob("*.py") if "__pycache__" not in p.parts] + + +def test_no_import_conversation_service(): + base = Path(super_agent_pkg.__file__).parent + offenders = [] + for file in _python_files_under(base): + tree = ast.parse(file.read_text()) + for node in ast.walk(tree): + if isinstance(node, ast.ImportFrom) and node.module: + if node.module.startswith("agentrun.conversation_service"): + offenders.append(str(file.relative_to(base))) + elif isinstance(node, ast.Import): + for alias in node.names: + if alias.name.startswith("agentrun.conversation_service"): + offenders.append(str(file.relative_to(base))) + assert not offenders, ( + "super_agent must not import conversation_service, found in:" + f" {offenders}" + ) + + +async def test_get_conversation_does_not_touch_session_store(): + """Calling agent.get_conversation_async must not touch SessionStore.""" + from agentrun.super_agent.agent import SuperAgent + + session_store_mock = MagicMock() + # Patch in case code accidentally imported it + with patch( + "agentrun.conversation_service.SessionStore", + session_store_mock, + create=True, + ): + # Stub the data-plane call so we don't hit the network + with patch( + "agentrun.super_agent.agent.SuperAgentDataAPI" + ) as data_api_factory: + + async def _noop(*args, **kwargs): + return {} + + data_api_factory.return_value.get_conversation_async = _noop + agent = SuperAgent(name="demo") + await agent.get_conversation_async("c1") + + # SessionStore should never have been called/instantiated + session_store_mock.assert_not_called() diff --git a/tests/unittests/super_agent/test_stream.py b/tests/unittests/super_agent/test_stream.py new file mode 100644 index 0000000..54c9590 --- /dev/null +++ b/tests/unittests/super_agent/test_stream.py @@ -0,0 +1,203 @@ +"""Unit tests for ``agentrun.super_agent.stream``.""" + +from typing import List +from unittest.mock import MagicMock + +import pytest + +from agentrun.super_agent.stream import InvokeStream, parse_sse_async, SSEEvent + + +class _FakeResponse: + """Replays a pre-canned list of SSE lines via ``aiter_lines``.""" + + def __init__(self, lines: List[str]): + self._lines = lines + + async def aiter_lines(self): + for line in self._lines: + yield line + + +async def _collect(lines: List[str]) -> List[SSEEvent]: + return [ev async for ev in parse_sse_async(_FakeResponse(lines))] + + +# ─── parse_sse_async ───────────────────────────────────────── + + +async def test_parse_sse_simple_event(): + events = await _collect(["event: m", "data: hi", "id: 1", ""]) + assert len(events) == 1 + ev = events[0] + assert ev.event == "m" + assert ev.data == "hi" + assert ev.id == "1" + + +async def test_parse_sse_multiline_data(): + events = await _collect(["data: a", "data: b", ""]) + assert len(events) == 1 + assert events[0].data == "a\nb" + + +async def test_parse_sse_comment_ignored(): + events = await _collect([": comment", "data: x", ""]) + assert len(events) == 1 + assert events[0].data == "x" + + +async def test_parse_sse_unknown_field_ignored(): + events = await _collect(["unknown: v", "data: x", ""]) + assert len(events) == 1 + assert events[0].data == "x" + + +async def test_parse_sse_retry_invalid_ignored(): + events = await _collect(["retry: not-a-number", "data: x", ""]) + assert len(events) == 1 + assert events[0].retry is None + + +async def test_parse_sse_retry_valid(): + events = await _collect(["retry: 5000", "data: x", ""]) + assert events[0].retry == 5000 + + +async def test_parse_sse_strip_leading_space_after_colon(): + events = await _collect(["data: hello", ""]) + assert events[0].data == "hello" + + +async def test_parse_sse_field_without_colon(): + events = await _collect(["data", ""]) + assert len(events) == 1 + assert events[0].data == "" + + +async def test_parse_sse_flush_at_stream_end(): + events = await _collect(["data: final"]) + assert len(events) == 1 + assert events[0].data == "final" + + +async def test_parse_sse_multiple_events(): + events = await _collect([ + "event: a", + "data: 1", + "", + "event: b", + "data: 2", + "", + ]) + assert len(events) == 2 + assert events[0].event == "a" + assert events[1].event == "b" + + +async def test_parse_sse_empty_line_without_content_skipped(): + # Two consecutive empty lines → no duplicate events + events = await _collect(["", "data: x", "", ""]) + assert len(events) == 1 + + +# ─── SSEEvent.data_json ────────────────────────────────────── + + +def test_sse_event_data_json_success(): + ev = SSEEvent(event="x", data='{"k":1}') + assert ev.data_json() == {"k": 1} + + +def test_sse_event_data_json_failure(): + assert SSEEvent(event="x", data="not json").data_json() is None + + +def test_sse_event_data_json_empty(): + assert SSEEvent(event="x", data="").data_json() is None + + +# ─── InvokeStream ──────────────────────────────────────────── + + +async def _make_stream(events: List[SSEEvent]) -> InvokeStream: + async def _gen(): + for ev in events: + yield ev + + async def _factory(): + return _gen() + + return InvokeStream( + conversation_id="c1", + session_id="s1", + stream_url="https://x.com/stream", + stream_headers={"X-Super-Agent-Session-Id": "s1"}, + _stream_factory=_factory, + ) + + +async def test_invoke_stream_async_iter(): + events = [ + SSEEvent(event="m", data="1"), + SSEEvent(event="m", data="2"), + SSEEvent(event="m", data="3"), + ] + stream = await _make_stream(events) + collected = [ev async for ev in stream] + assert len(collected) == 3 + assert [ev.data for ev in collected] == ["1", "2", "3"] + + +async def test_invoke_stream_aclose(): + closed = {"v": False} + + async def _gen(): + try: + yield SSEEvent(event="m", data="x") + finally: + closed["v"] = True + + async def _factory(): + return _gen() + + stream = InvokeStream( + conversation_id="c", + session_id="s", + stream_url="u", + stream_headers={}, + _stream_factory=_factory, + ) + # Advance one step to open the iterator + it = stream.__aiter__() + await it.__anext__() + await stream.aclose() + assert closed["v"] is True + # After close, iteration is terminated + with pytest.raises(StopAsyncIteration): + await it.__anext__() + + +async def test_invoke_stream_async_with(): + closed = {"v": False} + + async def _gen(): + try: + yield SSEEvent(event="m", data="x") + finally: + closed["v"] = True + + async def _factory(): + return _gen() + + stream = InvokeStream( + conversation_id="c", + session_id="s", + stream_url="u", + stream_headers={}, + _stream_factory=_factory, + ) + async with stream as s: + async for _ev in s: + break + assert closed["v"] is True diff --git a/tests/unittests/toolset/api/test_openapi.py b/tests/unittests/toolset/api/test_openapi.py index 0f2e82d..bb32eac 100644 --- a/tests/unittests/toolset/api/test_openapi.py +++ b/tests/unittests/toolset/api/test_openapi.py @@ -548,9 +548,7 @@ def test_post_with_ref_schema(self): "content": { "application/json": { "schema": { - "$ref": ( - "#/components/schemas/CreateOrderRequest" - ) + "$ref": "#/components/schemas/CreateOrderRequest" } } }, @@ -761,9 +759,7 @@ def test_invalid_ref_gracefully_handled(self): "content": { "application/json": { "schema": { - "$ref": ( - "#/components/schemas/NonExistent" - ) + "$ref": "#/components/schemas/NonExistent" } } } @@ -796,9 +792,7 @@ def test_external_ref_not_resolved(self): "content": { "application/json": { "schema": { - "$ref": ( - "https://example.com/schemas/external.json" - ) + "$ref": "https://example.com/schemas/external.json" } } } @@ -918,9 +912,7 @@ def _get_coffee_shop_schema(): "content": { "application/json": { "schema": { - "$ref": ( - "#/components/schemas/CreateOrderRequest" - ) + "$ref": "#/components/schemas/CreateOrderRequest" } } }, @@ -956,9 +948,7 @@ def _get_coffee_shop_schema(): "content": { "application/json": { "schema": { - "$ref": ( - "#/components/schemas/UpdateOrderStatusRequest" - ) + "$ref": "#/components/schemas/UpdateOrderStatusRequest" } } }, @@ -1229,9 +1219,7 @@ def test_tool_schema(self): "openapi": "3.0.1", "info": {"title": "Test", "version": "1.0"}, "servers": [{ - "url": ( - "https://1431999136518149.agentrun-data.cn-hangzhou.aliyuncs.com/tools/test/" - ) + "url": "https://1431999136518149.agentrun-data.cn-hangzhou.aliyuncs.com/tools/test/" }], "paths": { "/invoke": {