diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 523d758a0a..c7441d09f4 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -144,6 +144,7 @@ async def process( follow_up_capture: FollowUpCapture | None = None follow_up_consumed_marked = False follow_up_activated = False + typing_requested = False try: streaming_response = self.streaming_response if (enable_streaming := event.get_extra("enable_streaming")) is not None: @@ -178,7 +179,11 @@ async def process( ) return - await event.send_typing() + try: + typing_requested = True + await event.send_typing() + except Exception: + logger.warning("send_typing failed", exc_info=True) await call_event_hook(event, EventType.OnWaitingLLMRequestEvent) async with session_lock_manager.acquire_lock(event.unified_msg_origin): @@ -377,6 +382,11 @@ async def process( ) await event.send(MessageChain().message(error_text)) finally: + if typing_requested: + try: + await event.stop_typing() + except Exception: + logger.warning("stop_typing failed", exc_info=True) if follow_up_capture: await finalize_follow_up_capture( follow_up_capture, diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 82c03dbb0d..0ecd47fedc 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -293,6 +293,12 @@ async def send_typing(self) -> None: 默认实现为空,由具体平台按需重写。 """ + async def stop_typing(self) -> None: + """停止输入中状态。 + + 默认实现为空,由具体平台按需重写。 + """ + async def _pre_send(self) -> None: """调度器会在执行 send() 前调用该方法 deprecated in v3.5.18""" diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py index c47b58087e..67189a25da 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_adapter.py @@ -6,7 +6,7 @@ import io import time import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any, cast from urllib.parse import quote @@ -49,6 +49,17 @@ class OpenClawLoginSession: error: str | None = None +@dataclass +class TypingSessionState: + ticket: str | None = None + ticket_context_token: str | None = None + refresh_after: float = 0.0 + keepalive_task: asyncio.Task | None = None + cancel_task: asyncio.Task | None = None + owners: set[str] = field(default_factory=set) + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + + @register_platform_adapter( "weixin_oc", "个人微信", @@ -105,7 +116,16 @@ def __init__( self._sync_buf = "" self._qr_expired_count = 0 self._context_tokens: dict[str, str] = {} + self._typing_states: dict[str, TypingSessionState] = {} self._last_inbound_error = "" + self._typing_keepalive_interval_s = max( + 1, + int(platform_config.get("weixin_oc_typing_keepalive_interval", 5)), + ) + self._typing_ticket_ttl_s = max( + 5, + int(platform_config.get("weixin_oc_typing_ticket_ttl", 60)), + ) self.token = str(platform_config.get("weixin_oc_token", "")).strip() or None self.account_id = ( @@ -132,6 +152,316 @@ def _sync_client_state(self) -> None: self.client.api_timeout_ms = self.api_timeout_ms self.client.token = self.token + def _get_typing_state(self, user_id: str) -> TypingSessionState: + state = self._typing_states.get(user_id) + if state is None: + state = TypingSessionState() + self._typing_states[user_id] = state + return state + + def _typing_supported_for(self, user_id: str) -> bool: + if not self.token: + return False + return bool(self._context_tokens.get(user_id)) + + async def _cancel_task_safely( + self, + task: asyncio.Task | None, + *, + log_message: str | None = None, + log_args: tuple[Any, ...] = (), + ) -> None: + if task is None or task.done(): + return + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + if log_message is not None: + logger.warning(log_message, *log_args, exc_info=True) + + async def _ensure_typing_ticket( + self, + user_id: str, + state: TypingSessionState, + ) -> str | None: + now = time.monotonic() + context_token = self._context_tokens.get(user_id) + if not context_token: + return None + + if ( + state.ticket + and state.ticket_context_token == context_token + and state.refresh_after > now + ): + return state.ticket + + payload = await self.client.get_typing_config(user_id, context_token) + if int(payload.get("ret") or 0) != 0: + logger.warning( + "weixin_oc(%s): getconfig failed for %s: %s", + self.meta().id, + user_id, + payload.get("errmsg", ""), + ) + return None + + ticket = str(payload.get("typing_ticket", "")).strip() + if not ticket: + return None + + state.ticket = ticket + state.ticket_context_token = context_token + state.refresh_after = time.monotonic() + self._typing_ticket_ttl_s + return ticket + + async def _send_typing_state( + self, + user_id: str, + ticket: str, + *, + cancel: bool, + ) -> None: + payload = await self.client.send_typing_state(user_id, ticket, cancel=cancel) + if int(payload.get("ret") or 0) != 0: + raise RuntimeError( + f"sendtyping failed for {user_id}: {payload.get('errmsg', '')}" + ) + + async def _run_typing_keepalive(self, user_id: str) -> None: + restart_needed = False + try: + await self._typing_keepalive_loop(user_id) + except asyncio.CancelledError: + raise + except Exception as e: + state = self._typing_states.get(user_id) + if state is not None: + async with state.lock: + state.refresh_after = 0.0 + restart_needed = ( + bool(state.owners) and not self._shutdown_event.is_set() + ) + logger.warning( + "weixin_oc(%s): typing keepalive failed for %s: %s", + self.meta().id, + user_id, + e, + ) + finally: + state = self._typing_states.get(user_id) + current_task = asyncio.current_task() + if state is not None and state.keepalive_task is current_task: + state.keepalive_task = None + + if not restart_needed: + return + + await asyncio.sleep(self._typing_keepalive_interval_s) + state = self._typing_states.get(user_id) + if state is None or self._shutdown_event.is_set(): + return + + async with state.lock: + if not state.owners or state.keepalive_task is not None: + return + state.keepalive_task = asyncio.create_task( + self._run_typing_keepalive(user_id) + ) + + async def _typing_keepalive_loop(self, user_id: str) -> None: + while not self._shutdown_event.is_set(): + await asyncio.sleep(self._typing_keepalive_interval_s) + state = self._typing_states.get(user_id) + if state is None: + return + + async with state.lock: + if not state.owners: + return + try: + ticket = await self._ensure_typing_ticket(user_id, state) + except Exception as e: + state.refresh_after = 0.0 + logger.warning( + "weixin_oc(%s): refresh typing ticket failed for %s: %s", + self.meta().id, + user_id, + e, + ) + continue + if not ticket: + continue + try: + await self._send_typing_state(user_id, ticket, cancel=False) + except Exception as e: + state.refresh_after = 0.0 + logger.warning( + "weixin_oc(%s): typing keepalive send failed for %s: %s", + self.meta().id, + user_id, + e, + ) + + async def _delayed_cancel_typing(self, user_id: str, ticket: str) -> None: + await asyncio.sleep(0) + state = self._typing_states.get(user_id) + if state is None: + return + + current_task = asyncio.current_task() + async with state.lock: + if state.cancel_task is not current_task: + return + if state.owners or state.keepalive_task is not None: + state.cancel_task = None + return + + try: + await self._send_typing_state(user_id, ticket, cancel=True) + except asyncio.CancelledError: + raise + except Exception as e: + logger.warning( + "weixin_oc(%s): cancel typing failed for %s: %s", + self.meta().id, + user_id, + e, + ) + finally: + state = self._typing_states.get(user_id) + if state is None: + return + async with state.lock: + if state.cancel_task is current_task: + state.cancel_task = None + + async def start_typing(self, user_id: str, owner_id: str) -> None: + state = self._get_typing_state(user_id) + cancel_task: asyncio.Task | None = None + async with state.lock: + if owner_id in state.owners: + return + if not self._typing_supported_for(user_id): + return + if state.cancel_task is not None and not state.cancel_task.done(): + cancel_task = state.cancel_task + cancel_task.cancel() + state.cancel_task = None + try: + ticket = await self._ensure_typing_ticket(user_id, state) + except Exception as e: + logger.warning( + "weixin_oc(%s): ensure typing ticket failed for %s: %s", + self.meta().id, + user_id, + e, + ) + return + if not ticket: + return + + state.ticket = ticket + state.owners.add(owner_id) + if state.keepalive_task is not None and not state.keepalive_task.done(): + return + + try: + await self._send_typing_state(user_id, ticket, cancel=False) + except Exception as e: + state.refresh_after = 0.0 + logger.warning( + "weixin_oc(%s): send typing failed for %s: %s", + self.meta().id, + user_id, + e, + ) + + task = asyncio.create_task(self._run_typing_keepalive(user_id)) + state.keepalive_task = task + + if cancel_task is not None: + await self._cancel_task_safely( + cancel_task, + log_message="weixin_oc(%s): ignored error from cancelled typing task", + log_args=(self.meta().id,), + ) + + async def stop_typing(self, user_id: str, owner_id: str) -> None: + state = self._typing_states.get(user_id) + if state is None: + return + + task: asyncio.Task | None = None + async with state.lock: + if owner_id not in state.owners: + return + state.owners.remove(owner_id) + + if state.owners: + return + + task = state.keepalive_task + state.keepalive_task = None + + await self._cancel_task_safely( + task, + log_message="weixin_oc(%s): typing keepalive stop failed for %s", + log_args=(self.meta().id, user_id), + ) + + async with state.lock: + if state.owners: + return + ticket = state.ticket + if ticket: + if state.cancel_task is None or state.cancel_task.done(): + state.cancel_task = asyncio.create_task( + self._delayed_cancel_typing(user_id, ticket) + ) + + async def _cleanup_typing_tasks(self) -> None: + tasks: list[asyncio.Task] = [] + cancels: list[tuple[str, str]] = [] + for user_id, state in list(self._typing_states.items()): + if state.ticket and ( + state.owners + or state.keepalive_task is not None + or state.cancel_task is not None + ): + cancels.append((user_id, state.ticket)) + state.owners.clear() + if state.keepalive_task is not None and not state.keepalive_task.done(): + tasks.append(state.keepalive_task) + state.keepalive_task.cancel() + state.keepalive_task = None + if state.cancel_task is not None and not state.cancel_task.done(): + tasks.append(state.cancel_task) + state.cancel_task.cancel() + state.cancel_task = None + + for task in tasks: + await self._cancel_task_safely( + task, + log_message="weixin_oc(%s): typing cleanup failed", + log_args=(self.meta().id,), + ) + + for user_id, ticket in cancels: + try: + await self._send_typing_state(user_id, ticket, cancel=True) + except Exception as e: + logger.warning( + "weixin_oc(%s): typing cleanup cancel failed for %s: %s", + self.meta().id, + user_id, + e, + ) + def _load_account_state(self) -> None: if not self.token: token = str(self.config.get("weixin_oc_token", "")).strip() @@ -907,10 +1237,12 @@ async def run(self) -> None: except Exception as e: logger.exception("weixin_oc(%s): run failed: %s", self.meta().id, e) finally: + await self._cleanup_typing_tasks() await self.client.close() async def terminate(self) -> None: self._shutdown_event.set() + await self._cleanup_typing_tasks() def get_stats(self) -> dict: stat = super().get_stats() diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py index 5ea30d911c..51b0b6ed7c 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_client.py @@ -226,3 +226,44 @@ async def request_json( if not text: return {} return cast(dict[str, Any], json.loads(text)) + + async def get_typing_config( + self, + user_id: str, + context_token: str, + ) -> dict[str, Any]: + return await self.request_json( + "POST", + "ilink/bot/getconfig", + payload={ + "ilink_user_id": user_id, + "context_token": context_token, + "base_info": { + "channel_version": "astrbot", + }, + }, + token_required=True, + timeout_ms=self.api_timeout_ms, + ) + + async def send_typing_state( + self, + user_id: str, + typing_ticket: str, + *, + cancel: bool, + ) -> dict[str, Any]: + return await self.request_json( + "POST", + "ilink/bot/sendtyping", + payload={ + "ilink_user_id": user_id, + "typing_ticket": typing_ticket, + "status": 2 if cancel else 1, + "base_info": { + "channel_version": "astrbot", + }, + }, + token_required=True, + timeout_ms=self.api_timeout_ms, + ) diff --git a/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py b/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py index abe3b5a066..84a19a9e7b 100644 --- a/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py +++ b/astrbot/core/platform/sources/weixin_oc/weixin_oc_event.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import uuid from typing import TYPE_CHECKING from astrbot.api.event import AstrMessageEvent, MessageChain @@ -29,6 +30,12 @@ def __init__( ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.platform = platform + self._typing_owner_id: str | None = None + + def _get_typing_owner_id(self) -> str: + if not self._typing_owner_id: + self._typing_owner_id = uuid.uuid4().hex + return self._typing_owner_id @staticmethod def _segment_to_text(segment: BaseMessageComponent) -> str: @@ -58,6 +65,18 @@ async def send(self, message: MessageChain) -> None: await self.platform.send_by_session(self.session, message) await super().send(message) + async def send_typing(self) -> None: + await self.platform.start_typing( + self.session.session_id, + self._get_typing_owner_id(), + ) + + async def stop_typing(self) -> None: + await self.platform.stop_typing( + self.session.session_id, + self._get_typing_owner_id(), + ) + async def send_streaming(self, generator, use_fallback: bool = False): if not use_fallback: buffer = None diff --git a/tests/unit/test_astr_message_event.py b/tests/unit/test_astr_message_event.py index ac529318fe..89087d1cab 100644 --- a/tests/unit/test_astr_message_event.py +++ b/tests/unit/test_astr_message_event.py @@ -651,6 +651,15 @@ async def test_send_typing_default_empty(self, astr_message_event): await astr_message_event.send_typing() +class TestStopTyping: + """Tests for stop_typing method.""" + + @pytest.mark.asyncio + async def test_stop_typing_default_empty(self, astr_message_event): + """Test stop_typing default implementation is empty.""" + await astr_message_event.stop_typing() + + class TestReact: """Tests for react method.""" @@ -772,10 +781,12 @@ def test_get_sender_fields_without_sender_attr(self, astr_message_event): def test_get_message_type_with_non_enum_type(self, astr_message_event): """get_message_type should handle message_obj.type that is not a MessageType.""" + class DummyMessage: def __init__(self): self.type = "not_an_enum" self.message = [] + astr_message_event.message_obj = DummyMessage() message_type = astr_message_event.get_message_type() assert isinstance(message_type, MessageType)