Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 119 additions & 4 deletions ax_cli/commands/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,101 @@ class MentionEvent:
created_at: str | None
space_id: str
attachments: list[dict[str, Any]] | None = None
metadata: dict[str, Any] | None = None


def _string_value(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, (str, int, float)):
text = str(value).strip()
return text or None
return None


def _format_shared_object(metadata: dict[str, Any] | None, *, space_id: str) -> str | None:
if not isinstance(metadata, dict):
return None
forward = metadata.get("forward")
if not isinstance(forward, dict):
return None

fields = [
("resource_type", "resource_type"),
("resource_id", "resource_id"),
("task_id", "task_id"),
("context_key", "context_key"),
("resource_uri", "resource_uri"),
("source_message_id", "source_message_id"),
("source_card_id", "source_card_id"),
("title", "title"),
]
lines = ["Shared object:"]
for label, key in fields:
value = _string_value(forward.get(key))
if value:
lines.append(f"- {label}: {value}")

summary = _string_value(forward.get("summary"))
if summary:
lines.append(f"- summary: {summary}")

task_id = _string_value(forward.get("task_id"))
context_key = _string_value(forward.get("context_key"))
if task_id or context_key:
lines.append("")
lines.append("Suggested inspection:")
if task_id:
lines.append(f"- axctl tasks get {task_id} --space-id {space_id} --json")
if context_key:
lines.append(f"- axctl context get '{context_key}' --space-id {space_id} --json")

return "\n".join(lines) if len(lines) > 1 else None


def _format_attachments(attachments: list[dict[str, Any]] | None, *, space_id: str) -> str | None:
if not attachments:
return None
lines = ["Attachments:"]
context_keys: list[str] = []
for attachment in attachments:
if not isinstance(attachment, dict):
continue
filename = _string_value(attachment.get("filename")) or "attachment"
content_type = _string_value(attachment.get("content_type"))
attachment_id = _string_value(attachment.get("id") or attachment.get("attachment_id"))
context_key = _string_value(attachment.get("context_key") or attachment.get("key"))
details = [part for part in (content_type, f"id={attachment_id}" if attachment_id else None) if part]
if context_key:
details.append(f"context_key={context_key}")
context_keys.append(context_key)
lines.append(f"- {filename}" + (f" ({', '.join(details)})" if details else ""))
if context_keys:
lines.append("")
lines.append("Suggested attachment inspection:")
for key in context_keys:
lines.append(f"- axctl context get '{key}' --space-id {space_id} --json")
return "\n".join(lines) if len(lines) > 1 else None


def _enrich_prompt_for_agent(
prompt: str,
*,
metadata: dict[str, Any] | None,
attachments: list[dict[str, Any]] | None,
space_id: str,
) -> str:
blocks = [
block
for block in (
_format_shared_object(metadata, space_id=space_id),
_format_attachments(attachments, space_id=space_id),
)
if block
]
if not blocks:
return prompt
return prompt.rstrip() + "\n\n---\n" + "\n\n".join(blocks)


class ChannelBridge:
Expand Down Expand Up @@ -180,6 +275,10 @@ async def emit_mentions(self) -> None:
meta["parent_id"] = event.parent_id
if event.attachments:
meta["attachments"] = event.attachments
if isinstance(event.metadata, dict):
forward = event.metadata.get("forward")
if isinstance(forward, dict):
meta["forward"] = forward
await self.send_notification(
"notifications/claude/channel",
{
Expand Down Expand Up @@ -273,6 +372,8 @@ async def handle_get_messages(self, request_id: Any, arguments: dict[str, Any])
"content": event.prompt,
"parent_id": event.parent_id,
"ts": event.created_at,
"attachments": event.attachments or [],
"metadata": event.metadata or {},
}
for event in pending
],
Expand Down Expand Up @@ -526,12 +627,14 @@ def _sse_loop(bridge: ChannelBridge) -> None:
or (author_raw if isinstance(author_raw, str) else "unknown")
)

# Extract attachment metadata. SSE events often omit
# the full metadata.attachments that the REST API returns,
# so we first check the SSE payload and fall back to a
# lightweight GET /messages/{id} call when needed.
# Extract share + attachment metadata. SSE events often
# omit the full metadata.attachments that the REST API
# returns, so we first check the SSE payload and fall back
# to a lightweight GET /messages/{id} call when needed.
attachments = None
msg_metadata = data.get("metadata") or {}
if not isinstance(msg_metadata, dict):
msg_metadata = {}
if isinstance(msg_metadata, dict):
raw_attachments = msg_metadata.get("attachments") or msg_metadata.get("accepted_attachments")
if raw_attachments and isinstance(raw_attachments, list):
Expand All @@ -547,12 +650,23 @@ def _sse_loop(bridge: ChannelBridge) -> None:
if isinstance(full_msg, dict):
full_msg = full_msg.get("message", full_msg)
full_meta = (full_msg or {}).get("metadata") or {}
if not isinstance(full_meta, dict):
full_meta = {}
merged_meta = dict(full_meta)
merged_meta.update(msg_metadata)
msg_metadata = merged_meta
api_attachments = full_meta.get("attachments") or full_meta.get("accepted_attachments")
if api_attachments and isinstance(api_attachments, list):
attachments = api_attachments
bridge.log(f" fetched {len(attachments)} attachment(s) from REST API")
except Exception as exc:
bridge.log(f" attachment fetch failed: {exc}")
prompt = _enrich_prompt_for_agent(
prompt,
metadata=msg_metadata,
attachments=attachments,
space_id=bridge.space_id,
)

bridge.enqueue_from_thread(
MentionEvent(
Expand All @@ -565,6 +679,7 @@ def _sse_loop(bridge: ChannelBridge) -> None:
created_at=data.get("created_at"),
space_id=bridge.space_id,
attachments=attachments,
metadata=msg_metadata,
)
)
if reconnect_after_event:
Expand Down
107 changes: 107 additions & 0 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,107 @@ def capture_delivery(event):
assert delivered[0].prompt == "please check this"


def test_channel_materializes_shared_task_metadata_for_agent_prompt(monkeypatch):
class FakeSseClient(FakeClient):
def connect_sse(self, *, space_id):
assert space_id == "space-123"
return FakeSseResponse(
{
"id": "incoming-share",
"content": "@anvil can you see what I shared?",
"author": {"id": "user-123", "name": "madtank", "type": "user"},
"mentions": ["anvil"],
"metadata": {
"forward": {
"intent": "share",
"resource_type": "task",
"resource_id": "task-123",
"task_id": "task-123",
"source_message_id": "source-msg-123",
"source_card_id": "task-signal:task-123",
"title": "Fix Share delivery context",
"summary": "The recipient should know this is a task.",
}
},
}
)

def get_message(self, message_id):
raise AssertionError("SSE metadata was already complete")

client = FakeSseClient()
bridge = CaptureBridge(client)
delivered: list[MentionEvent] = []

def capture_delivery(event):
delivered.append(event)
bridge.shutdown.set()

bridge.enqueue_from_thread = capture_delivery
monkeypatch.setattr(channel_mod.time, "monotonic", lambda: 0)

channel_mod._sse_loop(bridge)

assert [event.message_id for event in delivered] == ["incoming-share"]
assert "can you see what I shared?" in delivered[0].prompt
assert "Shared object:" in delivered[0].prompt
assert "- resource_type: task" in delivered[0].prompt
assert "- task_id: task-123" in delivered[0].prompt
assert "axctl tasks get task-123 --space-id space-123 --json" in delivered[0].prompt
assert delivered[0].metadata["forward"]["resource_type"] == "task"


def test_channel_fetches_attachment_metadata_and_adds_inspection_hint(monkeypatch):
class FakeSseClient(FakeClient):
def connect_sse(self, *, space_id):
assert space_id == "space-123"
return FakeSseResponse(
{
"id": "incoming-image",
"content": "@anvil please inspect this image",
"author": {"id": "user-123", "name": "madtank", "type": "user"},
"mentions": ["anvil"],
"metadata": {},
}
)

def get_message(self, message_id):
assert message_id == "incoming-image"
attachment = {
"id": "att-123",
"filename": "image.png",
"content_type": "image/png",
"context_key": "upload:image.png:att-123",
}
return {"message": {"metadata": {"accepted_attachments": [attachment]}}}

client = FakeSseClient()
bridge = CaptureBridge(client)
delivered: list[MentionEvent] = []

def capture_delivery(event):
delivered.append(event)
bridge.shutdown.set()

bridge.enqueue_from_thread = capture_delivery
monkeypatch.setattr(channel_mod.time, "monotonic", lambda: 0)

channel_mod._sse_loop(bridge)

assert [event.message_id for event in delivered] == ["incoming-image"]
assert "Attachments:" in delivered[0].prompt
assert "image.png (image/png, id=att-123, context_key=upload:image.png:att-123)" in delivered[0].prompt
assert "axctl context get 'upload:image.png:att-123' --space-id space-123 --json" in delivered[0].prompt
assert delivered[0].attachments == [
{
"id": "att-123",
"filename": "image.png",
"content_type": "image/png",
"context_key": "upload:image.png:att-123",
}
]


def test_channel_processing_status_can_be_disabled():
client = FakeClient("axp_a_AgentKey.Secret")
bridge = CaptureBridge(client, processing_status=False)
Expand Down Expand Up @@ -217,6 +318,8 @@ def test_channel_get_messages_returns_pending_mentions():
raw_content="@anvil please check this",
created_at="2026-04-15T23:00:00Z",
space_id="space-123",
attachments=[{"id": "att-1", "filename": "notes.md"}],
metadata={"forward": {"resource_type": "context"}},
)
)

Expand All @@ -225,6 +328,8 @@ def test_channel_get_messages_returns_pending_mentions():
result = bridge.writes[0]["result"]
assert "incoming-123" in result["content"][0]["text"]
assert "please check this" in result["content"][0]["text"]
assert "notes.md" in result["content"][0]["text"]
assert "resource_type" in result["content"][0]["text"]
assert bridge._pending_mentions == []


Expand All @@ -243,6 +348,7 @@ async def run():
raw_content="@anvil please check this",
created_at=None,
space_id="space-123",
metadata={"forward": {"resource_type": "task", "task_id": "task-123"}},
)
)
task = asyncio.create_task(bridge.emit_mentions())
Expand All @@ -265,6 +371,7 @@ async def run():
assert "raw_content" not in meta
assert "conversation_id" not in meta
assert "parent_id" not in meta
assert meta["forward"] == {"resource_type": "task", "task_id": "task-123"}


def test_channel_env_file_sets_missing_runtime_env(monkeypatch, tmp_path):
Expand Down
Loading