Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 174 additions & 45 deletions packages/sdk-py/src/agent_relay/communicate/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from contextlib import suppress
from inspect import isawaitable
from typing import Any
from urllib.parse import quote

try:
import aiohttp
Expand All @@ -32,7 +33,12 @@


class RelayTransport:
"""Minimal Relaycast transport backed by aiohttp."""
"""Minimal Relaycast transport backed by aiohttp.

Auth model: the workspace API key is used for admin operations (registering
agents, listing agents); the per-agent token is used for everything an
agent does (post, reply, dm, websocket).
"""

def __init__(self, agent_name: str, config: RelayConfig) -> None:
self.agent_name = agent_name
Expand Down Expand Up @@ -88,21 +94,38 @@ async def send_http(
path: str,
*,
payload: dict[str, Any] | None = None,
as_agent: bool = False,
retry: bool = True,
) -> Any:
"""Send a request and return the unwrapped ``data`` field on success.

Set ``as_agent=True`` to authenticate with the per-agent token
(required for any operation that posts as the agent). Set
``retry=False`` for best-effort calls (cleanup paths) where waiting
out the exponential backoff would block shutdown.
"""
self._require_config()
session = await self._ensure_session()
url = f"{self._base_url()}{path}"
headers = {"Authorization": f"Bearer {self.config.api_key}"}

for attempt in range(1, HTTP_RETRY_ATTEMPTS + 1):
if as_agent:
if not self.token:
raise RelayConnectionError(401, "Agent not registered; no token available.")
bearer = self.token
else:
bearer = self.config.api_key
headers = {"Authorization": f"Bearer {bearer}"}

max_attempts = HTTP_RETRY_ATTEMPTS if retry else 1
for attempt in range(1, max_attempts + 1):
try:
async with session.request(method, url, json=payload, headers=headers) as response:
if response.status == 401:
raise RelayAuthError(await self._error_message(response))

if 500 <= response.status <= 599:
message = await self._error_message(response)
if attempt < HTTP_RETRY_ATTEMPTS:
if attempt < max_attempts:
await asyncio.sleep(min(2 ** (attempt - 1), WS_RECONNECT_MAX_DELAY))
continue
raise RelayConnectionError(response.status, message)
Expand All @@ -117,15 +140,16 @@ async def send_http(
return None

if response.content_type == "application/json":
return await response.json()
body = await response.json()
return self._unwrap(body)

return await response.text()
except RelayAuthError:
raise
except RelayConnectionError:
raise
except aiohttp.ClientError as exc:
if attempt < HTTP_RETRY_ATTEMPTS:
if attempt < max_attempts:
await asyncio.sleep(min(2 ** (attempt - 1), WS_RECONNECT_MAX_DELAY))
continue
raise RelayConnectionError(0, str(exc)) from exc
Expand All @@ -141,61 +165,106 @@ async def register_agent(self) -> str:
if self.agent_id is not None and self.token is not None:
return self.agent_id

payload = await self.send_http(
data = await self.send_http(
"POST",
"/v1/agents/register",
payload={"name": self.agent_name, "workspace": self.config.workspace},
"/v1/agents",
payload={"name": self.agent_name, "type": "agent"},
)
self.agent_id = payload["agent_id"]
self.token = payload["token"]
self.agent_id = data["id"]
self.token = data["token"]
return self.agent_id

async def unregister_agent(self) -> None:
if self.agent_id is None:
await self._close_session_if_idle()
return

agent_id = self.agent_id
await self.send_http("DELETE", f"/v1/agents/{agent_id}")
await self.send_http(
"DELETE",
f"/v1/agents/{quote(self.agent_name, safe='')}",
retry=False,
)
self.agent_id = None
self.token = None
Comment on lines +182 to 188
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Clear local identity when unregister request fails

unregister_agent() now performs a best-effort DELETE with retry=False, but agent_id/token are only cleared after that call succeeds. When the API returns a 5xx (the commit message notes this can happen on delete), disconnect() suppresses the exception and leaves stale credentials cached; a later connect() reuses them via register_agent()'s early return and can get stuck with invalid auth instead of re-registering. Move the local state reset into a finally (or clear before the request) so reconnects can recover after failed cleanup.

Useful? React with 👍 / 👎.

await self._close_session_if_idle()

async def send_dm(self, recipient: str, text: str) -> str:
await self._ensure_registered()
payload = await self.send_http(
data = await self.send_http(
"POST",
"/v1/messages/dm",
payload={"to": recipient, "text": text, "from": self.agent_name},
"/v1/dm",
payload={"to": recipient, "text": text},
as_agent=True,
)
return payload["message_id"]
# The hosted API returns the DM envelope; the message id is on the
# top-level ``id`` and also on ``message.id``.
if isinstance(data, dict):
if "id" in data:
return data["id"]
inner = data.get("message")
if isinstance(inner, dict) and "id" in inner:
return inner["id"]
raise RelayConnectionError(500, "DM response missing message id")

async def post_message(self, channel: str, text: str) -> str:
await self._ensure_registered()
payload = await self.send_http(
data = await self.send_http(
"POST",
"/v1/messages/channel",
payload={"channel": channel, "text": text, "from": self.agent_name},
f"/v1/channels/{quote(channel, safe='')}/messages",
payload={"text": text},
as_agent=True,
)
return payload["message_id"]
return data["id"]

async def reply(self, message_id: str, text: str) -> str:
await self._ensure_registered()
payload = await self.send_http(
data = await self.send_http(
"POST",
"/v1/messages/reply",
payload={"message_id": message_id, "text": text, "from": self.agent_name},
f"/v1/messages/{quote(message_id, safe='')}/replies",
payload={"text": text},
as_agent=True,
)
return payload["message_id"]
return data["id"]

async def check_inbox(self) -> list[Message]:
"""Polling fallback for environments where the WebSocket cannot connect.

Prefer surfacing any deliverable messages returned by ``/v1/inbox``.
Some deployments may only expose unread metadata; in that case this
method returns an empty list instead of raising.
"""
await self._ensure_registered()
payload = await self.send_http("GET", f"/v1/inbox/{self.agent_id}")
return [self._message_from_payload(item) for item in payload.get("messages", [])]
data = await self.send_http("GET", "/v1/inbox", as_agent=True)

if not isinstance(data, dict):
return []

raw_messages = data.get("messages")
if not isinstance(raw_messages, list):
return []

messages: list[Message] = []
for item in raw_messages:
if not isinstance(item, dict):
continue
messages.append(
Message(
sender=item.get("sender") or item.get("agent_name") or item.get("from") or "unknown",
text=item.get("text") or "",
channel=item.get("channel"),
thread_id=item.get("thread_id"),
timestamp=item.get("timestamp"),
message_id=item.get("message_id") or item.get("id"),
)
)

return messages

async def list_agents(self) -> list[str]:
payload = await self.send_http("GET", "/v1/agents")
return list(payload.get("agents", []))
data = await self.send_http("GET", "/v1/agents")
if isinstance(data, list):
return [item["name"] for item in data if isinstance(item, dict) and "name" in item]
return []

async def _ensure_registered(self) -> None:
if self.agent_id is None or self.token is None:
Expand All @@ -214,6 +283,20 @@ def _require_config(self, *, require_workspace: bool = False) -> None:
def _base_url(self) -> str:
return (self.config.base_url or DEFAULT_RELAY_BASE_URL).rstrip("/")

@staticmethod
def _unwrap(body: Any) -> Any:
"""Unwrap the ``{ok, data, error}`` envelope used by the hosted API.

Mock servers that return a plain payload pass through unchanged.
"""
if isinstance(body, dict) and "ok" in body:
if not body.get("ok", False):
error = body.get("error") or {}
message = error.get("message") if isinstance(error, dict) else str(error)
raise RelayConnectionError(400, message or "Request failed")
return body.get("data")
return body

async def _ensure_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession()
Expand All @@ -234,12 +317,16 @@ async def _connect_websocket(self) -> None:
if self._ws is not None and not self._ws.closed:
return

from urllib.parse import quote

session = await self._ensure_session()
ws_url = f"{self._ws_base_url()}/v1/ws/{self.agent_id}?token={quote(self.token, safe='')}"
ws_url = f"{self._ws_base_url()}/v1/ws?token={quote(self.token, safe='')}"
self._ws = await session.ws_connect(ws_url)

# Subscribe to channels declared on the config so message.created
# events for those channels are delivered to this socket.
channels = list(self.config.channels or [])
if channels:
await self._ws.send_json({"type": "subscribe", "channels": channels})

def _ws_base_url(self) -> str:
base_url = self._base_url()
if base_url.startswith("https://"):
Expand Down Expand Up @@ -286,31 +373,67 @@ async def _ws_loop(self) -> None:

async def _dispatch_ws_payload(self, raw_payload: str) -> None:
payload = json.loads(raw_payload)
if payload.get("type") == "ping":
if not isinstance(payload, dict):
return

event_type = payload.get("type")
if event_type == "ping":
if self._ws is not None and not self._ws.closed:
await self._ws.send_json({"type": "pong"})
return
if payload.get("type") != "message":

message = self._message_from_event(payload)
if message is None:
return

callback = self._message_callback
if callback is None:
return

result = callback(self._message_from_payload(payload))
result = callback(message)
if isawaitable(result):
await result

@staticmethod
def _message_from_payload(payload: dict[str, Any]) -> Message:
return Message(
sender=payload["sender"],
text=payload["text"],
channel=payload.get("channel"),
thread_id=payload.get("thread_id"),
timestamp=payload.get("timestamp"),
message_id=payload.get("message_id"),
)
def _message_from_event(payload: dict[str, Any]) -> Message | None:
"""Translate a hosted WebSocket event into the SDK's flat ``Message``.

Recognises ``message.created``, ``thread.reply``, ``dm.received``,
and ``group_dm.received``. Falls back to the flat
``{type:"message", sender, text}`` shape the mock server emits.
"""
event_type = payload.get("type")

if event_type in {"message.created", "thread.reply", "dm.received", "group_dm.received"}:
inner = payload.get("message")
if not isinstance(inner, dict):
return None
sender = inner.get("agent_name") or inner.get("agent_id") or ""
text = inner.get("text", "")
channel = payload.get("channel")
thread_id = payload.get("parent_id") or inner.get("thread_id")
message_id = inner.get("id")
timestamp = inner.get("created_at") or inner.get("timestamp")
return Message(
sender=sender,
text=text,
channel=channel,
thread_id=thread_id,
timestamp=timestamp,
message_id=message_id,
)

if event_type == "message" and "sender" in payload:
return Message(
sender=payload["sender"],
text=payload.get("text", ""),
channel=payload.get("channel"),
thread_id=payload.get("thread_id"),
timestamp=payload.get("timestamp"),
message_id=payload.get("message_id"),
)

return None

@staticmethod
async def _error_message(response: aiohttp.ClientResponse) -> str:
Expand All @@ -319,7 +442,13 @@ async def _error_message(response: aiohttp.ClientResponse) -> str:
except Exception:
text = await response.text()
return text or response.reason or "Request failed"
return str(payload.get("message") or response.reason or "Request failed")
if isinstance(payload, dict):
error = payload.get("error")
if isinstance(error, dict) and error.get("message"):
return str(error["message"])
if payload.get("message"):
return str(payload["message"])
return str(response.reason or "Request failed")


__all__ = ["RelayTransport"]
Loading
Loading