From a26c53208eb46c38f127a0994f90ccd350915376 Mon Sep 17 00:00:00 2001 From: Haoyuan Li Date: Thu, 26 Mar 2026 10:25:39 +0800 Subject: [PATCH 1/6] feat: Add "typing" state control for weixin_oc plateform --- .../method/agent_sub_stages/internal.py | 12 +- astrbot/core/platform/astr_message_event.py | 6 + .../sources/weixin_oc/weixin_oc_adapter.py | 330 +++++++++++- .../sources/weixin_oc/weixin_oc_client.py | 41 ++ .../sources/weixin_oc/weixin_oc_event.py | 19 + tests/unit/test_astr_message_event.py | 11 + tests/unit/test_internal_agent_sub_stage.py | 297 +++++++++++ tests/unit/test_weixin_oc_typing.py | 480 ++++++++++++++++++ 8 files changed, 1194 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_internal_agent_sub_stage.py create mode 100644 tests/unit/test_weixin_oc_typing.py diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 523d758a0a..3620124ea3 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -144,6 +144,7 @@ async def process( follow_up_capture: FollowUpCapture | None = None follow_up_consumed_marked = False follow_up_activated = False + typing_requested = False try: streaming_response = self.streaming_response if (enable_streaming := event.get_extra("enable_streaming")) is not None: @@ -178,7 +179,11 @@ async def process( ) return - await event.send_typing() + try: + typing_requested = True + await event.send_typing() + except Exception as e: + logger.warning("send_typing failed: %s", e) await call_event_hook(event, EventType.OnWaitingLLMRequestEvent) async with session_lock_manager.acquire_lock(event.unified_msg_origin): @@ -377,6 +382,11 @@ async def process( ) await event.send(MessageChain().message(error_text)) finally: + if typing_requested and not event.platform_meta.support_streaming_message: + try: + await event.stop_typing() + except Exception as e: + logger.warning("stop_typing failed: %s", e) if follow_up_capture: await finalize_follow_up_capture( follow_up_capture, diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 82c03dbb0d..0ecd47fedc 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -293,6 +293,12 @@ async def send_typing(self) -> None: 默认实现为空,由具体平台按需重写。 """ + async def stop_typing(self) -> None: + """停止输入中状态。 + + 默认实现为空,由具体平台按需重写。 + """ + async def _pre_send(self) -> None: """调度器会在执行 send() 前调用该方法 deprecated in v3.5.18""" diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py index c47b58087e..09896ef7f9 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py @@ -6,7 +6,7 @@ import io import time import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, cast from urllib.parse import quote @@ -49,6 +49,17 @@ class OpenClawLoginSession: error: str | None = None +@dataclass +class TypingSessionState: + ticket: str | None = None + ticket_context_token: str | None = None + refresh_after: float = 0.0 + keepalive_task: asyncio.Task | None = None + cancel_task: asyncio.Task | None = None + owners: set[str] = field(default_factory=set) + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + + @register_platform_adapter( "weixin_oc", "个人微信", @@ -105,7 +116,16 @@ def __init__( self._sync_buf = "" self._qr_expired_count = 0 self._context_tokens: dict[str, str] = {} + self._typing_states: dict[str, TypingSessionState] = {} self._last_inbound_error = "" + self._typing_keepalive_interval_s = max( + 1, + int(platform_config.get("weixin_oc_typing_keepalive_interval", 5)), + ) + self._typing_ticket_ttl_s = max( + 5, + int(platform_config.get("weixin_oc_typing_ticket_ttl", 60)), + ) self.token = str(platform_config.get("weixin_oc_token", "")).strip() or None self.account_id = ( @@ -132,6 +152,312 @@ def _sync_client_state(self) -> None: self.client.api_timeout_ms = self.api_timeout_ms self.client.token = self.token + def _get_typing_state(self, user_id: str) -> TypingSessionState: + state = self._typing_states.get(user_id) + if state is None: + state = TypingSessionState() + self._typing_states[user_id] = state + return state + + def _typing_supported_for(self, user_id: str) -> bool: + if not self.token: + return False + return bool(self._context_tokens.get(user_id)) + + async def _ensure_typing_ticket( + self, + user_id: str, + state: TypingSessionState, + ) -> str | None: + now = time.monotonic() + context_token = self._context_tokens.get(user_id) + if not context_token: + return None + + if ( + state.ticket + and state.ticket_context_token == context_token + and state.refresh_after > now + ): + return state.ticket + + payload = await self.client.get_typing_config(user_id, context_token) + if int(payload.get("ret") or 0) != 0: + logger.warning( + "weixin_oc(%s): getconfig failed for %s: %s", + self.meta().id, + user_id, + payload.get("errmsg", ""), + ) + return None + + ticket = str(payload.get("typing_ticket", "")).strip() + if not ticket: + return None + + state.ticket = ticket + state.ticket_context_token = context_token + state.refresh_after = time.monotonic() + self._typing_ticket_ttl_s + return ticket + + async def _send_typing_state( + self, + user_id: str, + ticket: str, + *, + cancel: bool, + ) -> None: + payload = await self.client.send_typing_state(user_id, ticket, cancel=cancel) + if int(payload.get("ret") or 0) != 0: + raise RuntimeError( + f"sendtyping failed for {user_id}: {payload.get('errmsg', '')}" + ) + + async def _run_typing_keepalive(self, user_id: str) -> None: + restart_needed = False + try: + await self._typing_keepalive_loop(user_id) + except asyncio.CancelledError: + raise + except Exception as e: + state = self._typing_states.get(user_id) + if state is not None: + async with state.lock: + state.refresh_after = 0.0 + restart_needed = ( + bool(state.owners) and not self._shutdown_event.is_set() + ) + logger.warning( + "weixin_oc(%s): typing keepalive failed for %s: %s", + self.meta().id, + user_id, + e, + ) + finally: + state = self._typing_states.get(user_id) + current_task = asyncio.current_task() + if state is not None and state.keepalive_task is current_task: + state.keepalive_task = None + + if not restart_needed: + return + + await asyncio.sleep(self._typing_keepalive_interval_s) + state = self._typing_states.get(user_id) + if state is None or self._shutdown_event.is_set(): + return + + async with state.lock: + if not state.owners or state.keepalive_task is not None: + return + state.keepalive_task = asyncio.create_task( + self._run_typing_keepalive(user_id) + ) + + async def _typing_keepalive_loop(self, user_id: str) -> None: + while not self._shutdown_event.is_set(): + await asyncio.sleep(self._typing_keepalive_interval_s) + state = self._typing_states.get(user_id) + if state is None: + return + + async with state.lock: + if not state.owners: + return + try: + ticket = await self._ensure_typing_ticket(user_id, state) + except Exception as e: + state.refresh_after = 0.0 + logger.warning( + "weixin_oc(%s): refresh typing ticket failed for %s: %s", + self.meta().id, + user_id, + e, + ) + continue + if not ticket: + continue + try: + await self._send_typing_state(user_id, ticket, cancel=False) + except Exception as e: + state.refresh_after = 0.0 + logger.warning( + "weixin_oc(%s): typing keepalive send failed for %s: %s", + self.meta().id, + user_id, + e, + ) + + async def _delayed_cancel_typing(self, user_id: str, ticket: str) -> None: + await asyncio.sleep(0) + state = self._typing_states.get(user_id) + if state is None: + return + + current_task = asyncio.current_task() + async with state.lock: + if state.cancel_task is not current_task: + return + if state.owners or state.keepalive_task is not None: + state.cancel_task = None + return + + try: + await self._send_typing_state(user_id, ticket, cancel=True) + except asyncio.CancelledError: + raise + except Exception as e: + logger.warning( + "weixin_oc(%s): cancel typing failed for %s: %s", + self.meta().id, + user_id, + e, + ) + finally: + state = self._typing_states.get(user_id) + if state is None: + return + async with state.lock: + if state.cancel_task is current_task: + state.cancel_task = None + + async def start_typing(self, user_id: str, owner_id: str) -> None: + state = self._get_typing_state(user_id) + cancel_task: asyncio.Task | None = None + async with state.lock: + if owner_id in state.owners: + return + if not self._typing_supported_for(user_id): + return + if state.cancel_task is not None and not state.cancel_task.done(): + cancel_task = state.cancel_task + cancel_task.cancel() + state.cancel_task = None + try: + ticket = await self._ensure_typing_ticket(user_id, state) + except Exception as e: + logger.warning( + "weixin_oc(%s): ensure typing ticket failed for %s: %s", + self.meta().id, + user_id, + e, + ) + return + if not ticket: + return + + state.ticket = ticket + state.owners.add(owner_id) + if state.keepalive_task is not None and not state.keepalive_task.done(): + return + + try: + await self._send_typing_state(user_id, ticket, cancel=False) + except Exception as e: + state.refresh_after = 0.0 + logger.warning( + "weixin_oc(%s): send typing failed for %s: %s", + self.meta().id, + user_id, + e, + ) + + task = asyncio.create_task(self._run_typing_keepalive(user_id)) + state.keepalive_task = task + + if cancel_task is not None: + try: + await cancel_task + except asyncio.CancelledError: + pass + except Exception: + pass + + async def stop_typing(self, user_id: str, owner_id: str) -> None: + state = self._typing_states.get(user_id) + if state is None: + return + + task: asyncio.Task | None = None + async with state.lock: + if owner_id in state.owners: + state.owners.remove(owner_id) + elif state.owners: + return + else: + return + + if state.owners: + return + + task = state.keepalive_task + state.keepalive_task = None + + if task is not None and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning( + "weixin_oc(%s): typing keepalive stop failed for %s: %s", + self.meta().id, + user_id, + e, + ) + + async with state.lock: + if state.owners: + return + ticket = state.ticket + if ticket: + if state.cancel_task is None or state.cancel_task.done(): + state.cancel_task = asyncio.create_task( + self._delayed_cancel_typing(user_id, ticket) + ) + + async def _cleanup_typing_tasks(self) -> None: + tasks: list[asyncio.Task] = [] + cancels: list[tuple[str, str]] = [] + for user_id, state in self._typing_states.items(): + if state.ticket and ( + state.owners + or state.keepalive_task is not None + or state.cancel_task is not None + ): + cancels.append((user_id, state.ticket)) + state.owners.clear() + if state.keepalive_task is not None and not state.keepalive_task.done(): + tasks.append(state.keepalive_task) + state.keepalive_task.cancel() + state.keepalive_task = None + if state.cancel_task is not None and not state.cancel_task.done(): + tasks.append(state.cancel_task) + state.cancel_task.cancel() + state.cancel_task = None + + for task in tasks: + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning( + "weixin_oc(%s): typing cleanup failed: %s", self.meta().id, e + ) + + for user_id, ticket in cancels: + try: + await self._send_typing_state(user_id, ticket, cancel=True) + except Exception as e: + logger.warning( + "weixin_oc(%s): typing cleanup cancel failed for %s: %s", + self.meta().id, + user_id, + e, + ) + def _load_account_state(self) -> None: if not self.token: token = str(self.config.get("weixin_oc_token", "")).strip() @@ -907,10 +1233,12 @@ async def run(self) -> None: except Exception as e: logger.exception("weixin_oc(%s): run failed: %s", self.meta().id, e) finally: + await self._cleanup_typing_tasks() await self.client.close() async def terminate(self) -> None: self._shutdown_event.set() + await self._cleanup_typing_tasks() def get_stats(self) -> dict: stat = super().get_stats() diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py index 5ea30d911c..51b0b6ed7c 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py @@ -226,3 +226,44 @@ async def request_json( if not text: return {} return cast(dict[str, Any], json.loads(text)) + + async def get_typing_config( + self, + user_id: str, + context_token: str, + ) -> dict[str, Any]: + return await self.request_json( + "POST", + "ilink/bot/getconfig", + payload={ + "ilink_user_id": user_id, + "context_token": context_token, + "base_info": { + "channel_version": "astrbot", + }, + }, + token_required=True, + timeout_ms=self.api_timeout_ms, + ) + + async def send_typing_state( + self, + user_id: str, + typing_ticket: str, + *, + cancel: bool, + ) -> dict[str, Any]: + return await self.request_json( + "POST", + "ilink/bot/sendtyping", + payload={ + "ilink_user_id": user_id, + "typing_ticket": typing_ticket, + "status": 2 if cancel else 1, + "base_info": { + "channel_version": "astrbot", + }, + }, + token_required=True, + timeout_ms=self.api_timeout_ms, + ) diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py index abe3b5a066..84a19a9e7b 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import uuid from typing import TYPE_CHECKING from astrbot.api.event import AstrMessageEvent, MessageChain @@ -29,6 +30,12 @@ def __init__( ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.platform = platform + self._typing_owner_id: str | None = None + + def _get_typing_owner_id(self) -> str: + if not self._typing_owner_id: + self._typing_owner_id = uuid.uuid4().hex + return self._typing_owner_id @staticmethod def _segment_to_text(segment: BaseMessageComponent) -> str: @@ -58,6 +65,18 @@ async def send(self, message: MessageChain) -> None: await self.platform.send_by_session(self.session, message) await super().send(message) + async def send_typing(self) -> None: + await self.platform.start_typing( + self.session.session_id, + self._get_typing_owner_id(), + ) + + async def stop_typing(self) -> None: + await self.platform.stop_typing( + self.session.session_id, + self._get_typing_owner_id(), + ) + async def send_streaming(self, generator, use_fallback: bool = False): if not use_fallback: buffer = None diff --git a/tests/unit/test_astr_message_event.py b/tests/unit/test_astr_message_event.py index ac529318fe..89087d1cab 100644 --- a/tests/unit/test_astr_message_event.py +++ b/tests/unit/test_astr_message_event.py @@ -651,6 +651,15 @@ async def test_send_typing_default_empty(self, astr_message_event): await astr_message_event.send_typing() +class TestStopTyping: + """Tests for stop_typing method.""" + + @pytest.mark.asyncio + async def test_stop_typing_default_empty(self, astr_message_event): + """Test stop_typing default implementation is empty.""" + await astr_message_event.stop_typing() + + class TestReact: """Tests for react method.""" @@ -772,10 +781,12 @@ def test_get_sender_fields_without_sender_attr(self, astr_message_event): def test_get_message_type_with_non_enum_type(self, astr_message_event): """get_message_type should handle message_obj.type that is not a MessageType.""" + class DummyMessage: def __init__(self): self.type = "not_an_enum" self.message = [] + astr_message_event.message_obj = DummyMessage() message_type = astr_message_event.get_message_type() assert isinstance(message_type, MessageType) diff --git a/tests/unit/test_internal_agent_sub_stage.py b/tests/unit/test_internal_agent_sub_stage.py new file mode 100644 index 0000000000..95722e2936 --- /dev/null +++ b/tests/unit/test_internal_agent_sub_stage.py @@ -0,0 +1,297 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.core.message.components import Plain +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember +from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.pipeline.process_stage.method.agent_sub_stages import ( + internal as internal_module, +) +from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import ( + InternalAgentSubStage, +) + + +class ConcreteAstrMessageEvent(AstrMessageEvent): + async def send(self, message): + await super().send(message) + + +@pytest.fixture +def mock_ctx(): + plugin_context = MagicMock() + plugin_context.conversation_manager = MagicMock() + plugin_context.get_config.return_value = {"timezone": "UTC"} + plugin_context.get_using_tts_provider.return_value = None + + ctx = MagicMock() + ctx.astrbot_config = { + "provider_settings": { + "streaming_response": False, + "unsupported_streaming_strategy": "turn_off", + "max_context_length": 32, + "dequeue_context_length": 4, + }, + "kb_agentic_mode": False, + "subagent_orchestrator": {}, + } + ctx.plugin_manager.context = plugin_context + return ctx + + +@pytest.fixture +def stage(mock_ctx): + async def _make_stage(): + obj = InternalAgentSubStage() + await obj.initialize(mock_ctx) + obj._save_to_history = AsyncMock() + return obj + + return _make_stage + + +@pytest.fixture +def event(): + platform_meta = PlatformMetadata( + name="test_platform", + description="Test platform", + id="test_platform_id", + support_streaming_message=False, + ) + message = AstrBotMessage() + message.type = MessageType.FRIEND_MESSAGE + message.self_id = "bot123" + message.session_id = "session123" + message.message_id = "msg123" + message.sender = MessageMember(user_id="user123", nickname="TestUser") + message.message = [Plain(text="Hello world")] + message.message_str = "Hello world" + message.raw_message = None + return ConcreteAstrMessageEvent( + message_str="Hello world", + message_obj=message, + platform_meta=platform_meta, + session_id="session123", + ) + + +@asynccontextmanager +async def fake_lock(_umo): + yield + + +def make_build_result() -> SimpleNamespace: + provider = MagicMock() + provider.provider_config = {"id": "provider-1", "api_base": ""} + provider.get_model.return_value = "test-model" + provider.meta.return_value = SimpleNamespace(type="test") + + final_resp = SimpleNamespace( + completion_text="done", + result_chain=None, + role="assistant", + usage=None, + ) + agent_runner = MagicMock() + agent_runner.done.return_value = True + agent_runner.was_aborted.return_value = False + agent_runner.get_final_llm_resp.return_value = final_resp + agent_runner.run_context = SimpleNamespace(messages=[]) + agent_runner.stats = MagicMock() + agent_runner.stats.to_dict.return_value = {} + agent_runner.provider = provider + + return SimpleNamespace( + agent_runner=agent_runner, + provider_request=SimpleNamespace( + system_prompt="sys", + func_tool=None, + conversation=object(), + tool_calls_result=None, + ), + provider=provider, + reset_coro=None, + ) + + +async def empty_run_agent(*args, **kwargs): + if False: + yield None + + +@pytest.mark.asyncio +async def test_process_swallows_send_typing_error_and_still_releases(stage, event): + event.send_typing = AsyncMock(side_effect=RuntimeError("boom")) + event.stop_typing = AsyncMock() + obj = await stage() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object(internal_module, "build_main_agent", AsyncMock(return_value=None)), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_releases_typing_when_build_returns_none(stage, event): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + obj = await stage() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object(internal_module, "build_main_agent", AsyncMock(return_value=None)), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.send_typing.assert_awaited_once() + event.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_releases_typing_when_llm_request_hook_short_circuits( + stage, event +): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + obj = await stage() + build_result = make_build_result() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object( + internal_module, + "call_event_hook", + AsyncMock(side_effect=[False, True]), + ), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object( + internal_module, + "build_main_agent", + AsyncMock(return_value=build_result), + ), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_releases_typing_after_normal_reply(stage, event): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + obj = await stage() + build_result = make_build_result() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object( + internal_module, + "call_event_hook", + AsyncMock(side_effect=[False, False]), + ), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object( + internal_module, + "build_main_agent", + AsyncMock(return_value=build_result), + ), + patch.object(internal_module, "run_agent", empty_run_agent), + patch.object(internal_module, "register_active_runner"), + patch.object(internal_module, "unregister_active_runner"), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_does_not_stop_typing_early_for_streaming_platforms(stage, event): + event.platform_meta.support_streaming_message = True + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + obj = await stage() + obj.streaming_response = True + build_result = make_build_result() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object( + internal_module, + "call_event_hook", + AsyncMock(side_effect=[False, False]), + ), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object( + internal_module, + "build_main_agent", + AsyncMock(return_value=build_result), + ), + patch.object(internal_module, "run_agent", empty_run_agent), + patch.object(internal_module, "register_active_runner"), + patch.object(internal_module, "unregister_active_runner"), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert len(results) == 1 + event.stop_typing.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_process_releases_typing_on_error_fallback_send(stage, event): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock() + event.send = AsyncMock() + obj = await stage() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object( + internal_module, + "build_main_agent", + AsyncMock(side_effect=RuntimeError("boom")), + ), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.send.assert_awaited_once() + event.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_process_swallows_stop_typing_error(stage, event): + event.send_typing = AsyncMock() + event.stop_typing = AsyncMock(side_effect=RuntimeError("stop failed")) + obj = await stage() + + with ( + patch.object(internal_module, "try_capture_follow_up", return_value=None), + patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), + patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), + patch.object(internal_module, "build_main_agent", AsyncMock(return_value=None)), + ): + results = [item async for item in obj.process(event, provider_wake_prefix="")] + + assert results == [] + event.send_typing.assert_awaited_once() + event.stop_typing.assert_awaited_once() diff --git a/tests/unit/test_weixin_oc_typing.py b/tests/unit/test_weixin_oc_typing.py new file mode 100644 index 0000000000..2ce3250747 --- /dev/null +++ b/tests/unit/test_weixin_oc_typing.py @@ -0,0 +1,480 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.core.message.components import Plain +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember +from astrbot.core.platform.message_type import MessageType +from astrbot.core.platform.platform_metadata import PlatformMetadata +from astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter import ( + TypingSessionState, + WeixinOCAdapter, +) +from astrbot.core.platform.sources.weixin_oc.weixin_oc_client import WeixinOCClient +from astrbot.core.platform.sources.weixin_oc.weixin_oc_event import WeixinOCMessageEvent + + +@pytest.fixture +def client(): + return WeixinOCClient( + adapter_id="wx-1", + base_url="https://example.com", + cdn_base_url="https://cdn.example.com", + api_timeout_ms=15000, + token="token-1", + ) + + +@pytest.fixture +def adapter(): + obj = WeixinOCAdapter( + platform_config={ + "id": "wx-1", + "type": "weixin_oc", + "weixin_oc_token": "token-1", + }, + platform_settings={}, + event_queue=asyncio.Queue(), + ) + obj._context_tokens["user-1"] = "ctx-1" + return obj + + +@pytest.fixture +def weixin_event(): + message = AstrBotMessage() + message.type = MessageType.FRIEND_MESSAGE + message.self_id = "bot123" + message.session_id = "user-1" + message.message_id = "msg123" + message.sender = MessageMember(user_id="user-1", nickname="User") + message.message = [Plain(text="hello")] + message.message_str = "hello" + message.raw_message = None + + platform = MagicMock() + platform.start_typing = AsyncMock() + platform.stop_typing = AsyncMock() + platform.send_by_session = AsyncMock() + + event = WeixinOCMessageEvent( + message_str="hello", + message_obj=message, + platform_meta=PlatformMetadata( + name="weixin_oc", + description="个人微信", + id="wx-1", + support_streaming_message=False, + ), + session_id="user-1", + platform=platform, + ) + return event, platform + + +@pytest.mark.asyncio +async def test_get_typing_config_uses_getconfig(client): + client.request_json = AsyncMock(return_value={"typing_ticket": "ticket-1"}) + + result = await client.get_typing_config("user-1", "ctx-1") + + assert result == {"typing_ticket": "ticket-1"} + client.request_json.assert_awaited_once_with( + "POST", + "ilink/bot/getconfig", + payload={ + "ilink_user_id": "user-1", + "context_token": "ctx-1", + "base_info": {"channel_version": "astrbot"}, + }, + token_required=True, + timeout_ms=client.api_timeout_ms, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("cancel, status", [(False, 1), (True, 2)]) +async def test_send_typing_state_uses_sendtyping(client, cancel, status): + client.request_json = AsyncMock(return_value={}) + + await client.send_typing_state("user-1", "ticket-1", cancel=cancel) + + client.request_json.assert_awaited_once_with( + "POST", + "ilink/bot/sendtyping", + payload={ + "ilink_user_id": "user-1", + "typing_ticket": "ticket-1", + "status": status, + "base_info": {"channel_version": "astrbot"}, + }, + token_required=True, + timeout_ms=client.api_timeout_ms, + ) + + +@pytest.mark.asyncio +async def test_event_delegates_typing_calls(weixin_event): + event, platform = weixin_event + + await event.send_typing() + await event.stop_typing() + + platform.start_typing.assert_awaited_once() + platform.stop_typing.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_event_reuses_stable_owner_id(weixin_event): + event, platform = weixin_event + + await event.send_typing() + await event.stop_typing() + + start_owner = platform.start_typing.await_args.args[1] + stop_owner = platform.stop_typing.await_args.args[1] + assert start_owner == stop_owner + + +@pytest.mark.asyncio +async def test_start_typing_skips_without_token(adapter): + adapter.token = None + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + adapter._send_typing_state = AsyncMock() + + await adapter.start_typing("user-1", "owner-a") + + adapter._ensure_typing_ticket.assert_not_awaited() + adapter._send_typing_state.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_start_typing_skips_without_context_token(adapter): + adapter._context_tokens.clear() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + adapter._send_typing_state = AsyncMock() + + await adapter.start_typing("user-1", "owner-a") + + adapter._ensure_typing_ticket.assert_not_awaited() + adapter._send_typing_state.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_ensure_typing_ticket_reuses_fresh_ticket(adapter): + state = TypingSessionState( + ticket="cached-ticket", + ticket_context_token="ctx-1", + refresh_after=float("inf"), + ) + adapter.client.get_typing_config = AsyncMock() + + result = await adapter._ensure_typing_ticket("user-1", state) + + assert result == "cached-ticket" + adapter.client.get_typing_config.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_ensure_typing_ticket_refreshes_stale_ticket(adapter): + state = TypingSessionState(ticket="stale-ticket", refresh_after=0.0) + adapter.client.get_typing_config = AsyncMock( + return_value={"typing_ticket": "fresh-ticket"} + ) + + result = await adapter._ensure_typing_ticket("user-1", state) + + assert result == "fresh-ticket" + assert state.ticket == "fresh-ticket" + adapter.client.get_typing_config.assert_awaited_once_with("user-1", "ctx-1") + + +@pytest.mark.asyncio +async def test_ensure_typing_ticket_refreshes_when_context_token_changes(adapter): + state = TypingSessionState( + ticket="cached-ticket", + ticket_context_token="ctx-1", + refresh_after=float("inf"), + ) + adapter._context_tokens["user-1"] = "ctx-2" + adapter.client.get_typing_config = AsyncMock( + return_value={"typing_ticket": "fresh-ticket"} + ) + + result = await adapter._ensure_typing_ticket("user-1", state) + + assert result == "fresh-ticket" + assert state.ticket_context_token == "ctx-2" + adapter.client.get_typing_config.assert_awaited_once_with("user-1", "ctx-2") + + +@pytest.mark.asyncio +async def test_send_typing_state_raises_on_nonzero_ret(adapter): + adapter.client.send_typing_state = AsyncMock( + return_value={"ret": 1, "errmsg": "expired"} + ) + + with pytest.raises(RuntimeError, match="sendtyping failed"): + await adapter._send_typing_state("user-1", "ticket-1", cancel=False) + + +@pytest.mark.asyncio +async def test_start_typing_same_owner_is_idempotent(adapter): + stop_event = asyncio.Event() + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + await stop_event.wait() + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await adapter.start_typing("user-1", "owner-a") + + assert adapter._send_typing_state.await_count == 1 + state = adapter._typing_states["user-1"] + assert state.owners == {"owner-a"} + + await adapter.stop_typing("user-1", "owner-a") + stop_event.set() + + +@pytest.mark.asyncio +async def test_stop_typing_only_cancels_on_last_owner(adapter): + stop_event = asyncio.Event() + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + await stop_event.wait() + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await adapter.start_typing("user-1", "owner-b") + await adapter.stop_typing("user-1", "owner-a") + + state = adapter._typing_states["user-1"] + assert state.owners == {"owner-b"} + assert adapter._send_typing_state.await_count == 1 + + await adapter.stop_typing("user-1", "owner-b") + await asyncio.sleep(0) + await asyncio.sleep(0) + stop_event.set() + assert adapter._send_typing_state.await_count == 2 + + +@pytest.mark.asyncio +async def test_stop_typing_is_safe_to_repeat(adapter): + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + await asyncio.Event().wait() + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await adapter.stop_typing("user-1", "owner-a") + await adapter.stop_typing("user-1", "owner-a") + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert adapter._send_typing_state.await_count == 2 + + +@pytest.mark.asyncio +async def test_keepalive_failure_cleans_state(adapter): + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + raise RuntimeError("keepalive failed") + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await asyncio.sleep(0) + + state = adapter._typing_states["user-1"] + assert state.keepalive_task is None + + await adapter.stop_typing("user-1", "owner-a") + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert adapter._send_typing_state.await_count == 2 + + +@pytest.mark.asyncio +async def test_keepalive_failure_restarts_for_active_owner(adapter): + adapter._typing_keepalive_interval_s = 0 + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + keepalive_round = 0 + stop_event = asyncio.Event() + + async def fake_keepalive(_user_id): + nonlocal keepalive_round + keepalive_round += 1 + if keepalive_round == 1: + raise RuntimeError("keepalive failed") + await stop_event.wait() + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + for _ in range(4): + await asyncio.sleep(0) + + state = adapter._typing_states["user-1"] + assert keepalive_round >= 2 + assert state.keepalive_task is not None + + stop_event.set() + await adapter.stop_typing("user-1", "owner-a") + for _ in range(2): + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_stop_typing_does_not_cancel_new_owner_session(adapter): + cancel_blocked = asyncio.Event() + allow_cancel_exit = asyncio.Event() + adapter._send_typing_state = AsyncMock() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_keepalive(_user_id): + try: + await asyncio.Event().wait() + except asyncio.CancelledError: + cancel_blocked.set() + await allow_cancel_exit.wait() + raise + + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + stop_task = asyncio.create_task(adapter.stop_typing("user-1", "owner-a")) + await cancel_blocked.wait() + await adapter.start_typing("user-1", "owner-b") + allow_cancel_exit.set() + await stop_task + + assert adapter._send_typing_state.await_count == 2 + + +@pytest.mark.asyncio +async def test_start_typing_cancels_inflight_cancel_task(adapter): + cancel_started = asyncio.Event() + release_cancel = asyncio.Event() + stop_event = asyncio.Event() + events: list[str] = [] + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + + async def fake_send_typing_state(_user_id, ticket, *, cancel): + if cancel: + events.append("cancel-start") + cancel_started.set() + try: + await release_cancel.wait() + except asyncio.CancelledError: + events.append("cancel-cancelled") + raise + events.append("cancel-finished") + return + events.append(f"start-{ticket}") + + async def fake_keepalive(_user_id): + await stop_event.wait() + + adapter._send_typing_state = fake_send_typing_state + adapter._typing_keepalive_loop = fake_keepalive + + await adapter.start_typing("user-1", "owner-a") + await adapter.stop_typing("user-1", "owner-a") + await asyncio.sleep(0) + await asyncio.sleep(0) + await cancel_started.wait() + + start_task = asyncio.create_task(adapter.start_typing("user-1", "owner-b")) + await asyncio.sleep(0) + release_cancel.set() + await start_task + + assert "cancel-cancelled" in events + assert "cancel-finished" not in events + + stop_event.set() + await adapter.stop_typing("user-1", "owner-b") + await asyncio.sleep(0) + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_cleanup_typing_tasks_sends_final_cancel(adapter): + adapter._send_typing_state = AsyncMock() + + async def fake_keepalive(_user_id): + await asyncio.Event().wait() + + task = asyncio.create_task(fake_keepalive("user-1")) + adapter._typing_states["user-1"] = TypingSessionState( + ticket="ticket-1", + refresh_after=float("inf"), + keepalive_task=task, + owners={"owner-a"}, + ) + + await adapter._cleanup_typing_tasks() + + adapter._send_typing_state.assert_awaited_once_with( + "user-1", + "ticket-1", + cancel=True, + ) + + +@pytest.mark.asyncio +async def test_run_finally_cancels_keepalive_before_client_close(adapter): + order: list[str] = [] + task = asyncio.create_task(asyncio.Event().wait()) + adapter._typing_states["user-1"] = TypingSessionState( + ticket="ticket-1", + refresh_after=float("inf"), + keepalive_task=task, + owners={"owner-a"}, + ) + adapter._cleanup_typing_tasks = AsyncMock( + side_effect=lambda: order.append("cleanup") + ) + adapter.client.close = AsyncMock(side_effect=lambda: order.append("close")) + + with patch.object( + adapter, + "_poll_inbound_updates", + AsyncMock(side_effect=RuntimeError("boom")), + ): + await adapter.run() + + assert order == ["cleanup", "close"] + + +@pytest.mark.asyncio +async def test_send_still_works_with_existing_event_behavior(weixin_event): + event, platform = weixin_event + + with patch( + "astrbot.core.platform.astr_message_event.Metric.upload", + new_callable=AsyncMock, + ): + await event.send(MessageChain([Plain("reply")])) + + platform.send_by_session.assert_awaited_once() From 7c764a13e0d5236cc0080ff79424ba36e955dbe8 Mon Sep 17 00:00:00 2001 From: Haoyuan Li Date: Thu, 26 Mar 2026 11:12:09 +0800 Subject: [PATCH 2/6] fix: avoid typing state mutation during cleanup --- astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py index 09896ef7f9..44ad7063e5 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py @@ -420,7 +420,7 @@ async def stop_typing(self, user_id: str, owner_id: str) -> None: async def _cleanup_typing_tasks(self) -> None: tasks: list[asyncio.Task] = [] cancels: list[tuple[str, str]] = [] - for user_id, state in self._typing_states.items(): + for user_id, state in list(self._typing_states.items()): if state.ticket and ( state.owners or state.keepalive_task is not None From da06e23ad5a2c0f98160b066a66e1c6f501e72dd Mon Sep 17 00:00:00 2001 From: Haoyuan Li Date: Thu, 26 Mar 2026 11:22:23 +0800 Subject: [PATCH 3/6] fix: preserve typing error tracebacks in logs --- .../method/agent_sub_stages/internal.py | 8 ++-- .../sources/weixin_oc/weixin_oc_adapter.py | 6 ++- tests/unit/test_internal_agent_sub_stage.py | 4 ++ tests/unit/test_weixin_oc_typing.py | 40 +++++++++++++++++++ 4 files changed, 53 insertions(+), 5 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 3620124ea3..651721ec20 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -182,8 +182,8 @@ async def process( try: typing_requested = True await event.send_typing() - except Exception as e: - logger.warning("send_typing failed: %s", e) + except Exception: + logger.warning("send_typing failed", exc_info=True) await call_event_hook(event, EventType.OnWaitingLLMRequestEvent) async with session_lock_manager.acquire_lock(event.unified_msg_origin): @@ -385,8 +385,8 @@ async def process( if typing_requested and not event.platform_meta.support_streaming_message: try: await event.stop_typing() - except Exception as e: - logger.warning("stop_typing failed: %s", e) + except Exception: + logger.warning("stop_typing failed", exc_info=True) if follow_up_capture: await finalize_follow_up_capture( follow_up_capture, diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py index 44ad7063e5..e1333b26d4 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py @@ -371,7 +371,11 @@ async def start_typing(self, user_id: str, owner_id: str) -> None: except asyncio.CancelledError: pass except Exception: - pass + logger.warning( + "weixin_oc(%s): ignored error from cancelled typing task", + self.meta().id, + exc_info=True, + ) async def stop_typing(self, user_id: str, owner_id: str) -> None: state = self._typing_states.get(user_id) diff --git a/tests/unit/test_internal_agent_sub_stage.py b/tests/unit/test_internal_agent_sub_stage.py index 95722e2936..69c4b1c296 100644 --- a/tests/unit/test_internal_agent_sub_stage.py +++ b/tests/unit/test_internal_agent_sub_stage.py @@ -133,6 +133,7 @@ async def test_process_swallows_send_typing_error_and_still_releases(stage, even obj = await stage() with ( + patch.object(internal_module.logger, "warning") as warning_mock, patch.object(internal_module, "try_capture_follow_up", return_value=None), patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), @@ -142,6 +143,7 @@ async def test_process_swallows_send_typing_error_and_still_releases(stage, even assert results == [] event.stop_typing.assert_awaited_once() + warning_mock.assert_called_once_with("send_typing failed", exc_info=True) @pytest.mark.asyncio @@ -285,6 +287,7 @@ async def test_process_swallows_stop_typing_error(stage, event): obj = await stage() with ( + patch.object(internal_module.logger, "warning") as warning_mock, patch.object(internal_module, "try_capture_follow_up", return_value=None), patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), @@ -295,3 +298,4 @@ async def test_process_swallows_stop_typing_error(stage, event): assert results == [] event.send_typing.assert_awaited_once() event.stop_typing.assert_awaited_once() + warning_mock.assert_called_once_with("stop_typing failed", exc_info=True) diff --git a/tests/unit/test_weixin_oc_typing.py b/tests/unit/test_weixin_oc_typing.py index 2ce3250747..9301be95ab 100644 --- a/tests/unit/test_weixin_oc_typing.py +++ b/tests/unit/test_weixin_oc_typing.py @@ -418,6 +418,46 @@ async def fake_keepalive(_user_id): await asyncio.sleep(0) +@pytest.mark.asyncio +async def test_start_typing_logs_ignored_cancel_task_errors(adapter): + stop_event = asyncio.Event() + adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") + state = adapter._get_typing_state("user-1") + + async def fake_send_typing_state(_user_id, _ticket, *, cancel): + return None + + async def fake_cancel_task(): + try: + await asyncio.Event().wait() + except asyncio.CancelledError as exc: + raise RuntimeError("cancel failed") from exc + + async def fake_keepalive(_user_id): + await stop_event.wait() + + adapter._send_typing_state = fake_send_typing_state + adapter._typing_keepalive_loop = fake_keepalive + state.cancel_task = asyncio.create_task(fake_cancel_task()) + await asyncio.sleep(0) + + with patch( + "astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter.logger.warning" + ) as warning_mock: + await adapter.start_typing("user-1", "owner-a") + + warning_mock.assert_called_once_with( + "weixin_oc(%s): ignored error from cancelled typing task", + adapter.meta().id, + exc_info=True, + ) + + stop_event.set() + await adapter.stop_typing("user-1", "owner-a") + await asyncio.sleep(0) + await asyncio.sleep(0) + + @pytest.mark.asyncio async def test_cleanup_typing_tasks_sends_final_cancel(adapter): adapter._send_typing_state = AsyncMock() From c2a4d6c90572fa4192881715d523fcd6e35404f9 Mon Sep 17 00:00:00 2001 From: Haoyuan Li Date: Thu, 26 Mar 2026 11:33:16 +0800 Subject: [PATCH 4/6] refactor: simplify typing task cancellation flow --- .../sources/weixin_oc/weixin_oc_adapter.py | 72 +++++++++---------- tests/unit/test_weixin_oc_typing.py | 27 +++++++ 2 files changed, 63 insertions(+), 36 deletions(-) diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py index e1333b26d4..67189a25da 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py @@ -164,6 +164,25 @@ def _typing_supported_for(self, user_id: str) -> bool: return False return bool(self._context_tokens.get(user_id)) + async def _cancel_task_safely( + self, + task: asyncio.Task | None, + *, + log_message: str | None = None, + log_args: tuple[Any, ...] = (), + ) -> None: + if task is None or task.done(): + return + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + if log_message is not None: + logger.warning(log_message, *log_args, exc_info=True) + async def _ensure_typing_ticket( self, user_id: str, @@ -366,16 +385,11 @@ async def start_typing(self, user_id: str, owner_id: str) -> None: state.keepalive_task = task if cancel_task is not None: - try: - await cancel_task - except asyncio.CancelledError: - pass - except Exception: - logger.warning( - "weixin_oc(%s): ignored error from cancelled typing task", - self.meta().id, - exc_info=True, - ) + await self._cancel_task_safely( + cancel_task, + log_message="weixin_oc(%s): ignored error from cancelled typing task", + log_args=(self.meta().id,), + ) async def stop_typing(self, user_id: str, owner_id: str) -> None: state = self._typing_states.get(user_id) @@ -384,12 +398,9 @@ async def stop_typing(self, user_id: str, owner_id: str) -> None: task: asyncio.Task | None = None async with state.lock: - if owner_id in state.owners: - state.owners.remove(owner_id) - elif state.owners: - return - else: + if owner_id not in state.owners: return + state.owners.remove(owner_id) if state.owners: return @@ -397,19 +408,11 @@ async def stop_typing(self, user_id: str, owner_id: str) -> None: task = state.keepalive_task state.keepalive_task = None - if task is not None and not task.done(): - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - except Exception as e: - logger.warning( - "weixin_oc(%s): typing keepalive stop failed for %s: %s", - self.meta().id, - user_id, - e, - ) + await self._cancel_task_safely( + task, + log_message="weixin_oc(%s): typing keepalive stop failed for %s", + log_args=(self.meta().id, user_id), + ) async with state.lock: if state.owners: @@ -442,14 +445,11 @@ async def _cleanup_typing_tasks(self) -> None: state.cancel_task = None for task in tasks: - try: - await task - except asyncio.CancelledError: - pass - except Exception as e: - logger.warning( - "weixin_oc(%s): typing cleanup failed: %s", self.meta().id, e - ) + await self._cancel_task_safely( + task, + log_message="weixin_oc(%s): typing cleanup failed", + log_args=(self.meta().id,), + ) for user_id, ticket in cancels: try: diff --git a/tests/unit/test_weixin_oc_typing.py b/tests/unit/test_weixin_oc_typing.py index 9301be95ab..26a270ff32 100644 --- a/tests/unit/test_weixin_oc_typing.py +++ b/tests/unit/test_weixin_oc_typing.py @@ -223,6 +223,33 @@ async def test_send_typing_state_raises_on_nonzero_ret(adapter): await adapter._send_typing_state("user-1", "ticket-1", cancel=False) +@pytest.mark.asyncio +async def test_cancel_task_safely_logs_task_errors(adapter): + async def failing_task(): + try: + await asyncio.Event().wait() + except asyncio.CancelledError as exc: + raise RuntimeError("task wait failed") from exc + + task = asyncio.create_task(failing_task()) + await asyncio.sleep(0) + + with patch( + "astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter.logger.warning" + ) as warning_mock: + await adapter._cancel_task_safely( + task, + log_message="weixin_oc(%s): typing cleanup failed", + log_args=(adapter.meta().id,), + ) + + warning_mock.assert_called_once_with( + "weixin_oc(%s): typing cleanup failed", + adapter.meta().id, + exc_info=True, + ) + + @pytest.mark.asyncio async def test_start_typing_same_owner_is_idempotent(adapter): stop_event = asyncio.Event() From db3659e8094aeb5ad41697d6db1ede62f7265596 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 27 Mar 2026 15:43:27 +0800 Subject: [PATCH 5/6] chore: remove tests --- tests/unit/test_internal_agent_sub_stage.py | 301 ----------- tests/unit/test_weixin_oc_typing.py | 547 -------------------- 2 files changed, 848 deletions(-) delete mode 100644 tests/unit/test_internal_agent_sub_stage.py delete mode 100644 tests/unit/test_weixin_oc_typing.py diff --git a/tests/unit/test_internal_agent_sub_stage.py b/tests/unit/test_internal_agent_sub_stage.py deleted file mode 100644 index 69c4b1c296..0000000000 --- a/tests/unit/test_internal_agent_sub_stage.py +++ /dev/null @@ -1,301 +0,0 @@ -from __future__ import annotations - -from contextlib import asynccontextmanager -from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from astrbot.core.message.components import Plain -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember -from astrbot.core.platform.message_type import MessageType -from astrbot.core.platform.platform_metadata import PlatformMetadata -from astrbot.core.pipeline.process_stage.method.agent_sub_stages import ( - internal as internal_module, -) -from astrbot.core.pipeline.process_stage.method.agent_sub_stages.internal import ( - InternalAgentSubStage, -) - - -class ConcreteAstrMessageEvent(AstrMessageEvent): - async def send(self, message): - await super().send(message) - - -@pytest.fixture -def mock_ctx(): - plugin_context = MagicMock() - plugin_context.conversation_manager = MagicMock() - plugin_context.get_config.return_value = {"timezone": "UTC"} - plugin_context.get_using_tts_provider.return_value = None - - ctx = MagicMock() - ctx.astrbot_config = { - "provider_settings": { - "streaming_response": False, - "unsupported_streaming_strategy": "turn_off", - "max_context_length": 32, - "dequeue_context_length": 4, - }, - "kb_agentic_mode": False, - "subagent_orchestrator": {}, - } - ctx.plugin_manager.context = plugin_context - return ctx - - -@pytest.fixture -def stage(mock_ctx): - async def _make_stage(): - obj = InternalAgentSubStage() - await obj.initialize(mock_ctx) - obj._save_to_history = AsyncMock() - return obj - - return _make_stage - - -@pytest.fixture -def event(): - platform_meta = PlatformMetadata( - name="test_platform", - description="Test platform", - id="test_platform_id", - support_streaming_message=False, - ) - message = AstrBotMessage() - message.type = MessageType.FRIEND_MESSAGE - message.self_id = "bot123" - message.session_id = "session123" - message.message_id = "msg123" - message.sender = MessageMember(user_id="user123", nickname="TestUser") - message.message = [Plain(text="Hello world")] - message.message_str = "Hello world" - message.raw_message = None - return ConcreteAstrMessageEvent( - message_str="Hello world", - message_obj=message, - platform_meta=platform_meta, - session_id="session123", - ) - - -@asynccontextmanager -async def fake_lock(_umo): - yield - - -def make_build_result() -> SimpleNamespace: - provider = MagicMock() - provider.provider_config = {"id": "provider-1", "api_base": ""} - provider.get_model.return_value = "test-model" - provider.meta.return_value = SimpleNamespace(type="test") - - final_resp = SimpleNamespace( - completion_text="done", - result_chain=None, - role="assistant", - usage=None, - ) - agent_runner = MagicMock() - agent_runner.done.return_value = True - agent_runner.was_aborted.return_value = False - agent_runner.get_final_llm_resp.return_value = final_resp - agent_runner.run_context = SimpleNamespace(messages=[]) - agent_runner.stats = MagicMock() - agent_runner.stats.to_dict.return_value = {} - agent_runner.provider = provider - - return SimpleNamespace( - agent_runner=agent_runner, - provider_request=SimpleNamespace( - system_prompt="sys", - func_tool=None, - conversation=object(), - tool_calls_result=None, - ), - provider=provider, - reset_coro=None, - ) - - -async def empty_run_agent(*args, **kwargs): - if False: - yield None - - -@pytest.mark.asyncio -async def test_process_swallows_send_typing_error_and_still_releases(stage, event): - event.send_typing = AsyncMock(side_effect=RuntimeError("boom")) - event.stop_typing = AsyncMock() - obj = await stage() - - with ( - patch.object(internal_module.logger, "warning") as warning_mock, - patch.object(internal_module, "try_capture_follow_up", return_value=None), - patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), - patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), - patch.object(internal_module, "build_main_agent", AsyncMock(return_value=None)), - ): - results = [item async for item in obj.process(event, provider_wake_prefix="")] - - assert results == [] - event.stop_typing.assert_awaited_once() - warning_mock.assert_called_once_with("send_typing failed", exc_info=True) - - -@pytest.mark.asyncio -async def test_process_releases_typing_when_build_returns_none(stage, event): - event.send_typing = AsyncMock() - event.stop_typing = AsyncMock() - obj = await stage() - - with ( - patch.object(internal_module, "try_capture_follow_up", return_value=None), - patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), - patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), - patch.object(internal_module, "build_main_agent", AsyncMock(return_value=None)), - ): - results = [item async for item in obj.process(event, provider_wake_prefix="")] - - assert results == [] - event.send_typing.assert_awaited_once() - event.stop_typing.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_process_releases_typing_when_llm_request_hook_short_circuits( - stage, event -): - event.send_typing = AsyncMock() - event.stop_typing = AsyncMock() - obj = await stage() - build_result = make_build_result() - - with ( - patch.object(internal_module, "try_capture_follow_up", return_value=None), - patch.object( - internal_module, - "call_event_hook", - AsyncMock(side_effect=[False, True]), - ), - patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), - patch.object( - internal_module, - "build_main_agent", - AsyncMock(return_value=build_result), - ), - ): - results = [item async for item in obj.process(event, provider_wake_prefix="")] - - assert results == [] - event.stop_typing.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_process_releases_typing_after_normal_reply(stage, event): - event.send_typing = AsyncMock() - event.stop_typing = AsyncMock() - obj = await stage() - build_result = make_build_result() - - with ( - patch.object(internal_module, "try_capture_follow_up", return_value=None), - patch.object( - internal_module, - "call_event_hook", - AsyncMock(side_effect=[False, False]), - ), - patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), - patch.object( - internal_module, - "build_main_agent", - AsyncMock(return_value=build_result), - ), - patch.object(internal_module, "run_agent", empty_run_agent), - patch.object(internal_module, "register_active_runner"), - patch.object(internal_module, "unregister_active_runner"), - ): - results = [item async for item in obj.process(event, provider_wake_prefix="")] - - assert results == [] - event.stop_typing.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_process_does_not_stop_typing_early_for_streaming_platforms(stage, event): - event.platform_meta.support_streaming_message = True - event.send_typing = AsyncMock() - event.stop_typing = AsyncMock() - obj = await stage() - obj.streaming_response = True - build_result = make_build_result() - - with ( - patch.object(internal_module, "try_capture_follow_up", return_value=None), - patch.object( - internal_module, - "call_event_hook", - AsyncMock(side_effect=[False, False]), - ), - patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), - patch.object( - internal_module, - "build_main_agent", - AsyncMock(return_value=build_result), - ), - patch.object(internal_module, "run_agent", empty_run_agent), - patch.object(internal_module, "register_active_runner"), - patch.object(internal_module, "unregister_active_runner"), - ): - results = [item async for item in obj.process(event, provider_wake_prefix="")] - - assert len(results) == 1 - event.stop_typing.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_process_releases_typing_on_error_fallback_send(stage, event): - event.send_typing = AsyncMock() - event.stop_typing = AsyncMock() - event.send = AsyncMock() - obj = await stage() - - with ( - patch.object(internal_module, "try_capture_follow_up", return_value=None), - patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), - patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), - patch.object( - internal_module, - "build_main_agent", - AsyncMock(side_effect=RuntimeError("boom")), - ), - ): - results = [item async for item in obj.process(event, provider_wake_prefix="")] - - assert results == [] - event.send.assert_awaited_once() - event.stop_typing.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_process_swallows_stop_typing_error(stage, event): - event.send_typing = AsyncMock() - event.stop_typing = AsyncMock(side_effect=RuntimeError("stop failed")) - obj = await stage() - - with ( - patch.object(internal_module.logger, "warning") as warning_mock, - patch.object(internal_module, "try_capture_follow_up", return_value=None), - patch.object(internal_module, "call_event_hook", AsyncMock(return_value=False)), - patch.object(internal_module.session_lock_manager, "acquire_lock", fake_lock), - patch.object(internal_module, "build_main_agent", AsyncMock(return_value=None)), - ): - results = [item async for item in obj.process(event, provider_wake_prefix="")] - - assert results == [] - event.send_typing.assert_awaited_once() - event.stop_typing.assert_awaited_once() - warning_mock.assert_called_once_with("stop_typing failed", exc_info=True) diff --git a/tests/unit/test_weixin_oc_typing.py b/tests/unit/test_weixin_oc_typing.py deleted file mode 100644 index 26a270ff32..0000000000 --- a/tests/unit/test_weixin_oc_typing.py +++ /dev/null @@ -1,547 +0,0 @@ -from __future__ import annotations - -import asyncio -from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from astrbot.core.message.components import Plain -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember -from astrbot.core.platform.message_type import MessageType -from astrbot.core.platform.platform_metadata import PlatformMetadata -from astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter import ( - TypingSessionState, - WeixinOCAdapter, -) -from astrbot.core.platform.sources.weixin_oc.weixin_oc_client import WeixinOCClient -from astrbot.core.platform.sources.weixin_oc.weixin_oc_event import WeixinOCMessageEvent - - -@pytest.fixture -def client(): - return WeixinOCClient( - adapter_id="wx-1", - base_url="https://example.com", - cdn_base_url="https://cdn.example.com", - api_timeout_ms=15000, - token="token-1", - ) - - -@pytest.fixture -def adapter(): - obj = WeixinOCAdapter( - platform_config={ - "id": "wx-1", - "type": "weixin_oc", - "weixin_oc_token": "token-1", - }, - platform_settings={}, - event_queue=asyncio.Queue(), - ) - obj._context_tokens["user-1"] = "ctx-1" - return obj - - -@pytest.fixture -def weixin_event(): - message = AstrBotMessage() - message.type = MessageType.FRIEND_MESSAGE - message.self_id = "bot123" - message.session_id = "user-1" - message.message_id = "msg123" - message.sender = MessageMember(user_id="user-1", nickname="User") - message.message = [Plain(text="hello")] - message.message_str = "hello" - message.raw_message = None - - platform = MagicMock() - platform.start_typing = AsyncMock() - platform.stop_typing = AsyncMock() - platform.send_by_session = AsyncMock() - - event = WeixinOCMessageEvent( - message_str="hello", - message_obj=message, - platform_meta=PlatformMetadata( - name="weixin_oc", - description="个人微信", - id="wx-1", - support_streaming_message=False, - ), - session_id="user-1", - platform=platform, - ) - return event, platform - - -@pytest.mark.asyncio -async def test_get_typing_config_uses_getconfig(client): - client.request_json = AsyncMock(return_value={"typing_ticket": "ticket-1"}) - - result = await client.get_typing_config("user-1", "ctx-1") - - assert result == {"typing_ticket": "ticket-1"} - client.request_json.assert_awaited_once_with( - "POST", - "ilink/bot/getconfig", - payload={ - "ilink_user_id": "user-1", - "context_token": "ctx-1", - "base_info": {"channel_version": "astrbot"}, - }, - token_required=True, - timeout_ms=client.api_timeout_ms, - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("cancel, status", [(False, 1), (True, 2)]) -async def test_send_typing_state_uses_sendtyping(client, cancel, status): - client.request_json = AsyncMock(return_value={}) - - await client.send_typing_state("user-1", "ticket-1", cancel=cancel) - - client.request_json.assert_awaited_once_with( - "POST", - "ilink/bot/sendtyping", - payload={ - "ilink_user_id": "user-1", - "typing_ticket": "ticket-1", - "status": status, - "base_info": {"channel_version": "astrbot"}, - }, - token_required=True, - timeout_ms=client.api_timeout_ms, - ) - - -@pytest.mark.asyncio -async def test_event_delegates_typing_calls(weixin_event): - event, platform = weixin_event - - await event.send_typing() - await event.stop_typing() - - platform.start_typing.assert_awaited_once() - platform.stop_typing.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_event_reuses_stable_owner_id(weixin_event): - event, platform = weixin_event - - await event.send_typing() - await event.stop_typing() - - start_owner = platform.start_typing.await_args.args[1] - stop_owner = platform.stop_typing.await_args.args[1] - assert start_owner == stop_owner - - -@pytest.mark.asyncio -async def test_start_typing_skips_without_token(adapter): - adapter.token = None - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - adapter._send_typing_state = AsyncMock() - - await adapter.start_typing("user-1", "owner-a") - - adapter._ensure_typing_ticket.assert_not_awaited() - adapter._send_typing_state.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_start_typing_skips_without_context_token(adapter): - adapter._context_tokens.clear() - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - adapter._send_typing_state = AsyncMock() - - await adapter.start_typing("user-1", "owner-a") - - adapter._ensure_typing_ticket.assert_not_awaited() - adapter._send_typing_state.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_ensure_typing_ticket_reuses_fresh_ticket(adapter): - state = TypingSessionState( - ticket="cached-ticket", - ticket_context_token="ctx-1", - refresh_after=float("inf"), - ) - adapter.client.get_typing_config = AsyncMock() - - result = await adapter._ensure_typing_ticket("user-1", state) - - assert result == "cached-ticket" - adapter.client.get_typing_config.assert_not_awaited() - - -@pytest.mark.asyncio -async def test_ensure_typing_ticket_refreshes_stale_ticket(adapter): - state = TypingSessionState(ticket="stale-ticket", refresh_after=0.0) - adapter.client.get_typing_config = AsyncMock( - return_value={"typing_ticket": "fresh-ticket"} - ) - - result = await adapter._ensure_typing_ticket("user-1", state) - - assert result == "fresh-ticket" - assert state.ticket == "fresh-ticket" - adapter.client.get_typing_config.assert_awaited_once_with("user-1", "ctx-1") - - -@pytest.mark.asyncio -async def test_ensure_typing_ticket_refreshes_when_context_token_changes(adapter): - state = TypingSessionState( - ticket="cached-ticket", - ticket_context_token="ctx-1", - refresh_after=float("inf"), - ) - adapter._context_tokens["user-1"] = "ctx-2" - adapter.client.get_typing_config = AsyncMock( - return_value={"typing_ticket": "fresh-ticket"} - ) - - result = await adapter._ensure_typing_ticket("user-1", state) - - assert result == "fresh-ticket" - assert state.ticket_context_token == "ctx-2" - adapter.client.get_typing_config.assert_awaited_once_with("user-1", "ctx-2") - - -@pytest.mark.asyncio -async def test_send_typing_state_raises_on_nonzero_ret(adapter): - adapter.client.send_typing_state = AsyncMock( - return_value={"ret": 1, "errmsg": "expired"} - ) - - with pytest.raises(RuntimeError, match="sendtyping failed"): - await adapter._send_typing_state("user-1", "ticket-1", cancel=False) - - -@pytest.mark.asyncio -async def test_cancel_task_safely_logs_task_errors(adapter): - async def failing_task(): - try: - await asyncio.Event().wait() - except asyncio.CancelledError as exc: - raise RuntimeError("task wait failed") from exc - - task = asyncio.create_task(failing_task()) - await asyncio.sleep(0) - - with patch( - "astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter.logger.warning" - ) as warning_mock: - await adapter._cancel_task_safely( - task, - log_message="weixin_oc(%s): typing cleanup failed", - log_args=(adapter.meta().id,), - ) - - warning_mock.assert_called_once_with( - "weixin_oc(%s): typing cleanup failed", - adapter.meta().id, - exc_info=True, - ) - - -@pytest.mark.asyncio -async def test_start_typing_same_owner_is_idempotent(adapter): - stop_event = asyncio.Event() - adapter._send_typing_state = AsyncMock() - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - - async def fake_keepalive(_user_id): - await stop_event.wait() - - adapter._typing_keepalive_loop = fake_keepalive - - await adapter.start_typing("user-1", "owner-a") - await adapter.start_typing("user-1", "owner-a") - - assert adapter._send_typing_state.await_count == 1 - state = adapter._typing_states["user-1"] - assert state.owners == {"owner-a"} - - await adapter.stop_typing("user-1", "owner-a") - stop_event.set() - - -@pytest.mark.asyncio -async def test_stop_typing_only_cancels_on_last_owner(adapter): - stop_event = asyncio.Event() - adapter._send_typing_state = AsyncMock() - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - - async def fake_keepalive(_user_id): - await stop_event.wait() - - adapter._typing_keepalive_loop = fake_keepalive - - await adapter.start_typing("user-1", "owner-a") - await adapter.start_typing("user-1", "owner-b") - await adapter.stop_typing("user-1", "owner-a") - - state = adapter._typing_states["user-1"] - assert state.owners == {"owner-b"} - assert adapter._send_typing_state.await_count == 1 - - await adapter.stop_typing("user-1", "owner-b") - await asyncio.sleep(0) - await asyncio.sleep(0) - stop_event.set() - assert adapter._send_typing_state.await_count == 2 - - -@pytest.mark.asyncio -async def test_stop_typing_is_safe_to_repeat(adapter): - adapter._send_typing_state = AsyncMock() - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - - async def fake_keepalive(_user_id): - await asyncio.Event().wait() - - adapter._typing_keepalive_loop = fake_keepalive - - await adapter.start_typing("user-1", "owner-a") - await adapter.stop_typing("user-1", "owner-a") - await adapter.stop_typing("user-1", "owner-a") - await asyncio.sleep(0) - await asyncio.sleep(0) - - assert adapter._send_typing_state.await_count == 2 - - -@pytest.mark.asyncio -async def test_keepalive_failure_cleans_state(adapter): - adapter._send_typing_state = AsyncMock() - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - - async def fake_keepalive(_user_id): - raise RuntimeError("keepalive failed") - - adapter._typing_keepalive_loop = fake_keepalive - - await adapter.start_typing("user-1", "owner-a") - await asyncio.sleep(0) - - state = adapter._typing_states["user-1"] - assert state.keepalive_task is None - - await adapter.stop_typing("user-1", "owner-a") - await asyncio.sleep(0) - await asyncio.sleep(0) - - assert adapter._send_typing_state.await_count == 2 - - -@pytest.mark.asyncio -async def test_keepalive_failure_restarts_for_active_owner(adapter): - adapter._typing_keepalive_interval_s = 0 - adapter._send_typing_state = AsyncMock() - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - keepalive_round = 0 - stop_event = asyncio.Event() - - async def fake_keepalive(_user_id): - nonlocal keepalive_round - keepalive_round += 1 - if keepalive_round == 1: - raise RuntimeError("keepalive failed") - await stop_event.wait() - - adapter._typing_keepalive_loop = fake_keepalive - - await adapter.start_typing("user-1", "owner-a") - for _ in range(4): - await asyncio.sleep(0) - - state = adapter._typing_states["user-1"] - assert keepalive_round >= 2 - assert state.keepalive_task is not None - - stop_event.set() - await adapter.stop_typing("user-1", "owner-a") - for _ in range(2): - await asyncio.sleep(0) - - -@pytest.mark.asyncio -async def test_stop_typing_does_not_cancel_new_owner_session(adapter): - cancel_blocked = asyncio.Event() - allow_cancel_exit = asyncio.Event() - adapter._send_typing_state = AsyncMock() - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - - async def fake_keepalive(_user_id): - try: - await asyncio.Event().wait() - except asyncio.CancelledError: - cancel_blocked.set() - await allow_cancel_exit.wait() - raise - - adapter._typing_keepalive_loop = fake_keepalive - - await adapter.start_typing("user-1", "owner-a") - stop_task = asyncio.create_task(adapter.stop_typing("user-1", "owner-a")) - await cancel_blocked.wait() - await adapter.start_typing("user-1", "owner-b") - allow_cancel_exit.set() - await stop_task - - assert adapter._send_typing_state.await_count == 2 - - -@pytest.mark.asyncio -async def test_start_typing_cancels_inflight_cancel_task(adapter): - cancel_started = asyncio.Event() - release_cancel = asyncio.Event() - stop_event = asyncio.Event() - events: list[str] = [] - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - - async def fake_send_typing_state(_user_id, ticket, *, cancel): - if cancel: - events.append("cancel-start") - cancel_started.set() - try: - await release_cancel.wait() - except asyncio.CancelledError: - events.append("cancel-cancelled") - raise - events.append("cancel-finished") - return - events.append(f"start-{ticket}") - - async def fake_keepalive(_user_id): - await stop_event.wait() - - adapter._send_typing_state = fake_send_typing_state - adapter._typing_keepalive_loop = fake_keepalive - - await adapter.start_typing("user-1", "owner-a") - await adapter.stop_typing("user-1", "owner-a") - await asyncio.sleep(0) - await asyncio.sleep(0) - await cancel_started.wait() - - start_task = asyncio.create_task(adapter.start_typing("user-1", "owner-b")) - await asyncio.sleep(0) - release_cancel.set() - await start_task - - assert "cancel-cancelled" in events - assert "cancel-finished" not in events - - stop_event.set() - await adapter.stop_typing("user-1", "owner-b") - await asyncio.sleep(0) - await asyncio.sleep(0) - - -@pytest.mark.asyncio -async def test_start_typing_logs_ignored_cancel_task_errors(adapter): - stop_event = asyncio.Event() - adapter._ensure_typing_ticket = AsyncMock(return_value="ticket-1") - state = adapter._get_typing_state("user-1") - - async def fake_send_typing_state(_user_id, _ticket, *, cancel): - return None - - async def fake_cancel_task(): - try: - await asyncio.Event().wait() - except asyncio.CancelledError as exc: - raise RuntimeError("cancel failed") from exc - - async def fake_keepalive(_user_id): - await stop_event.wait() - - adapter._send_typing_state = fake_send_typing_state - adapter._typing_keepalive_loop = fake_keepalive - state.cancel_task = asyncio.create_task(fake_cancel_task()) - await asyncio.sleep(0) - - with patch( - "astrbot.core.platform.sources.weixin_oc.weixin_oc_adapter.logger.warning" - ) as warning_mock: - await adapter.start_typing("user-1", "owner-a") - - warning_mock.assert_called_once_with( - "weixin_oc(%s): ignored error from cancelled typing task", - adapter.meta().id, - exc_info=True, - ) - - stop_event.set() - await adapter.stop_typing("user-1", "owner-a") - await asyncio.sleep(0) - await asyncio.sleep(0) - - -@pytest.mark.asyncio -async def test_cleanup_typing_tasks_sends_final_cancel(adapter): - adapter._send_typing_state = AsyncMock() - - async def fake_keepalive(_user_id): - await asyncio.Event().wait() - - task = asyncio.create_task(fake_keepalive("user-1")) - adapter._typing_states["user-1"] = TypingSessionState( - ticket="ticket-1", - refresh_after=float("inf"), - keepalive_task=task, - owners={"owner-a"}, - ) - - await adapter._cleanup_typing_tasks() - - adapter._send_typing_state.assert_awaited_once_with( - "user-1", - "ticket-1", - cancel=True, - ) - - -@pytest.mark.asyncio -async def test_run_finally_cancels_keepalive_before_client_close(adapter): - order: list[str] = [] - task = asyncio.create_task(asyncio.Event().wait()) - adapter._typing_states["user-1"] = TypingSessionState( - ticket="ticket-1", - refresh_after=float("inf"), - keepalive_task=task, - owners={"owner-a"}, - ) - adapter._cleanup_typing_tasks = AsyncMock( - side_effect=lambda: order.append("cleanup") - ) - adapter.client.close = AsyncMock(side_effect=lambda: order.append("close")) - - with patch.object( - adapter, - "_poll_inbound_updates", - AsyncMock(side_effect=RuntimeError("boom")), - ): - await adapter.run() - - assert order == ["cleanup", "close"] - - -@pytest.mark.asyncio -async def test_send_still_works_with_existing_event_behavior(weixin_event): - event, platform = weixin_event - - with patch( - "astrbot.core.platform.astr_message_event.Metric.upload", - new_callable=AsyncMock, - ): - await event.send(MessageChain([Plain("reply")])) - - platform.send_by_session.assert_awaited_once() From e95fd3d28bdca9ed496ba4bddeb09f226f5c5df5 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Fri, 27 Mar 2026 15:53:36 +0800 Subject: [PATCH 6/6] fix: remove unnecessary platform check for stopping typing --- .../pipeline/process_stage/method/agent_sub_stages/internal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 651721ec20..c7441d09f4 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -382,7 +382,7 @@ async def process( ) await event.send(MessageChain().message(error_text)) finally: - if typing_requested and not event.platform_meta.support_streaming_message: + if typing_requested: try: await event.stop_typing() except Exception: