diff --git a/ax_cli/commands/channel.py b/ax_cli/commands/channel.py index 5c97bcf..2424c4b 100644 --- a/ax_cli/commands/channel.py +++ b/ax_cli/commands/channel.py @@ -59,6 +59,101 @@ class MentionEvent: created_at: str | None space_id: str attachments: list[dict[str, Any]] | None = None + metadata: dict[str, Any] | None = None + + +def _string_value(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, (str, int, float)): + text = str(value).strip() + return text or None + return None + + +def _format_shared_object(metadata: dict[str, Any] | None, *, space_id: str) -> str | None: + if not isinstance(metadata, dict): + return None + forward = metadata.get("forward") + if not isinstance(forward, dict): + return None + + fields = [ + ("resource_type", "resource_type"), + ("resource_id", "resource_id"), + ("task_id", "task_id"), + ("context_key", "context_key"), + ("resource_uri", "resource_uri"), + ("source_message_id", "source_message_id"), + ("source_card_id", "source_card_id"), + ("title", "title"), + ] + lines = ["Shared object:"] + for label, key in fields: + value = _string_value(forward.get(key)) + if value: + lines.append(f"- {label}: {value}") + + summary = _string_value(forward.get("summary")) + if summary: + lines.append(f"- summary: {summary}") + + task_id = _string_value(forward.get("task_id")) + context_key = _string_value(forward.get("context_key")) + if task_id or context_key: + lines.append("") + lines.append("Suggested inspection:") + if task_id: + lines.append(f"- axctl tasks get {task_id} --space-id {space_id} --json") + if context_key: + lines.append(f"- axctl context get '{context_key}' --space-id {space_id} --json") + + return "\n".join(lines) if len(lines) > 1 else None + + +def _format_attachments(attachments: list[dict[str, Any]] | None, *, space_id: str) -> str | None: + if not attachments: + return None + lines = ["Attachments:"] + context_keys: list[str] = [] + for attachment in attachments: + if not isinstance(attachment, dict): + continue + filename = _string_value(attachment.get("filename")) or "attachment" + content_type = _string_value(attachment.get("content_type")) + attachment_id = _string_value(attachment.get("id") or attachment.get("attachment_id")) + context_key = _string_value(attachment.get("context_key") or attachment.get("key")) + details = [part for part in (content_type, f"id={attachment_id}" if attachment_id else None) if part] + if context_key: + details.append(f"context_key={context_key}") + context_keys.append(context_key) + lines.append(f"- {filename}" + (f" ({', '.join(details)})" if details else "")) + if context_keys: + lines.append("") + lines.append("Suggested attachment inspection:") + for key in context_keys: + lines.append(f"- axctl context get '{key}' --space-id {space_id} --json") + return "\n".join(lines) if len(lines) > 1 else None + + +def _enrich_prompt_for_agent( + prompt: str, + *, + metadata: dict[str, Any] | None, + attachments: list[dict[str, Any]] | None, + space_id: str, +) -> str: + blocks = [ + block + for block in ( + _format_shared_object(metadata, space_id=space_id), + _format_attachments(attachments, space_id=space_id), + ) + if block + ] + if not blocks: + return prompt + return prompt.rstrip() + "\n\n---\n" + "\n\n".join(blocks) class ChannelBridge: @@ -180,6 +275,10 @@ async def emit_mentions(self) -> None: meta["parent_id"] = event.parent_id if event.attachments: meta["attachments"] = event.attachments + if isinstance(event.metadata, dict): + forward = event.metadata.get("forward") + if isinstance(forward, dict): + meta["forward"] = forward await self.send_notification( "notifications/claude/channel", { @@ -273,6 +372,8 @@ async def handle_get_messages(self, request_id: Any, arguments: dict[str, Any]) "content": event.prompt, "parent_id": event.parent_id, "ts": event.created_at, + "attachments": event.attachments or [], + "metadata": event.metadata or {}, } for event in pending ], @@ -526,12 +627,14 @@ def _sse_loop(bridge: ChannelBridge) -> None: or (author_raw if isinstance(author_raw, str) else "unknown") ) - # Extract attachment metadata. SSE events often omit - # the full metadata.attachments that the REST API returns, - # so we first check the SSE payload and fall back to a - # lightweight GET /messages/{id} call when needed. + # Extract share + attachment metadata. SSE events often + # omit the full metadata.attachments that the REST API + # returns, so we first check the SSE payload and fall back + # to a lightweight GET /messages/{id} call when needed. attachments = None msg_metadata = data.get("metadata") or {} + if not isinstance(msg_metadata, dict): + msg_metadata = {} if isinstance(msg_metadata, dict): raw_attachments = msg_metadata.get("attachments") or msg_metadata.get("accepted_attachments") if raw_attachments and isinstance(raw_attachments, list): @@ -547,12 +650,23 @@ def _sse_loop(bridge: ChannelBridge) -> None: if isinstance(full_msg, dict): full_msg = full_msg.get("message", full_msg) full_meta = (full_msg or {}).get("metadata") or {} + if not isinstance(full_meta, dict): + full_meta = {} + merged_meta = dict(full_meta) + merged_meta.update(msg_metadata) + msg_metadata = merged_meta api_attachments = full_meta.get("attachments") or full_meta.get("accepted_attachments") if api_attachments and isinstance(api_attachments, list): attachments = api_attachments bridge.log(f" fetched {len(attachments)} attachment(s) from REST API") except Exception as exc: bridge.log(f" attachment fetch failed: {exc}") + prompt = _enrich_prompt_for_agent( + prompt, + metadata=msg_metadata, + attachments=attachments, + space_id=bridge.space_id, + ) bridge.enqueue_from_thread( MentionEvent( @@ -565,6 +679,7 @@ def _sse_loop(bridge: ChannelBridge) -> None: created_at=data.get("created_at"), space_id=bridge.space_id, attachments=attachments, + metadata=msg_metadata, ) ) if reconnect_after_event: diff --git a/tests/test_channel.py b/tests/test_channel.py index edf5a37..f8bca4f 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -170,6 +170,107 @@ def capture_delivery(event): assert delivered[0].prompt == "please check this" +def test_channel_materializes_shared_task_metadata_for_agent_prompt(monkeypatch): + class FakeSseClient(FakeClient): + def connect_sse(self, *, space_id): + assert space_id == "space-123" + return FakeSseResponse( + { + "id": "incoming-share", + "content": "@anvil can you see what I shared?", + "author": {"id": "user-123", "name": "madtank", "type": "user"}, + "mentions": ["anvil"], + "metadata": { + "forward": { + "intent": "share", + "resource_type": "task", + "resource_id": "task-123", + "task_id": "task-123", + "source_message_id": "source-msg-123", + "source_card_id": "task-signal:task-123", + "title": "Fix Share delivery context", + "summary": "The recipient should know this is a task.", + } + }, + } + ) + + def get_message(self, message_id): + raise AssertionError("SSE metadata was already complete") + + client = FakeSseClient() + bridge = CaptureBridge(client) + delivered: list[MentionEvent] = [] + + def capture_delivery(event): + delivered.append(event) + bridge.shutdown.set() + + bridge.enqueue_from_thread = capture_delivery + monkeypatch.setattr(channel_mod.time, "monotonic", lambda: 0) + + channel_mod._sse_loop(bridge) + + assert [event.message_id for event in delivered] == ["incoming-share"] + assert "can you see what I shared?" in delivered[0].prompt + assert "Shared object:" in delivered[0].prompt + assert "- resource_type: task" in delivered[0].prompt + assert "- task_id: task-123" in delivered[0].prompt + assert "axctl tasks get task-123 --space-id space-123 --json" in delivered[0].prompt + assert delivered[0].metadata["forward"]["resource_type"] == "task" + + +def test_channel_fetches_attachment_metadata_and_adds_inspection_hint(monkeypatch): + class FakeSseClient(FakeClient): + def connect_sse(self, *, space_id): + assert space_id == "space-123" + return FakeSseResponse( + { + "id": "incoming-image", + "content": "@anvil please inspect this image", + "author": {"id": "user-123", "name": "madtank", "type": "user"}, + "mentions": ["anvil"], + "metadata": {}, + } + ) + + def get_message(self, message_id): + assert message_id == "incoming-image" + attachment = { + "id": "att-123", + "filename": "image.png", + "content_type": "image/png", + "context_key": "upload:image.png:att-123", + } + return {"message": {"metadata": {"accepted_attachments": [attachment]}}} + + client = FakeSseClient() + bridge = CaptureBridge(client) + delivered: list[MentionEvent] = [] + + def capture_delivery(event): + delivered.append(event) + bridge.shutdown.set() + + bridge.enqueue_from_thread = capture_delivery + monkeypatch.setattr(channel_mod.time, "monotonic", lambda: 0) + + channel_mod._sse_loop(bridge) + + assert [event.message_id for event in delivered] == ["incoming-image"] + assert "Attachments:" in delivered[0].prompt + assert "image.png (image/png, id=att-123, context_key=upload:image.png:att-123)" in delivered[0].prompt + assert "axctl context get 'upload:image.png:att-123' --space-id space-123 --json" in delivered[0].prompt + assert delivered[0].attachments == [ + { + "id": "att-123", + "filename": "image.png", + "content_type": "image/png", + "context_key": "upload:image.png:att-123", + } + ] + + def test_channel_processing_status_can_be_disabled(): client = FakeClient("axp_a_AgentKey.Secret") bridge = CaptureBridge(client, processing_status=False) @@ -217,6 +318,8 @@ def test_channel_get_messages_returns_pending_mentions(): raw_content="@anvil please check this", created_at="2026-04-15T23:00:00Z", space_id="space-123", + attachments=[{"id": "att-1", "filename": "notes.md"}], + metadata={"forward": {"resource_type": "context"}}, ) ) @@ -225,6 +328,8 @@ def test_channel_get_messages_returns_pending_mentions(): result = bridge.writes[0]["result"] assert "incoming-123" in result["content"][0]["text"] assert "please check this" in result["content"][0]["text"] + assert "notes.md" in result["content"][0]["text"] + assert "resource_type" in result["content"][0]["text"] assert bridge._pending_mentions == [] @@ -243,6 +348,7 @@ async def run(): raw_content="@anvil please check this", created_at=None, space_id="space-123", + metadata={"forward": {"resource_type": "task", "task_id": "task-123"}}, ) ) task = asyncio.create_task(bridge.emit_mentions()) @@ -265,6 +371,7 @@ async def run(): assert "raw_content" not in meta assert "conversation_id" not in meta assert "parent_id" not in meta + assert meta["forward"] == {"resource_type": "task", "task_id": "task-123"} def test_channel_env_file_sets_missing_runtime_env(monkeypatch, tmp_path):