Skip to content

Commit

Permalink
feat(utils): 删除 utils.Condition 类,改用标准库中的 asyncio.Condition 类
Browse files Browse the repository at this point in the history
BREAKING CHANGE: 删除 utils.Condition 类
  • Loading branch information
st1020 committed Aug 21, 2022
1 parent c008711 commit 3e9db43
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 121 deletions.
24 changes: 13 additions & 11 deletions alicebot/__init__.py
Expand Up @@ -41,7 +41,6 @@
LoadModuleError,
)
from alicebot.utils import (
Condition,
ModuleInfo,
ModulePathFinder,
samefile,
Expand Down Expand Up @@ -79,7 +78,8 @@ class Bot:
plugin_state: Dict[str, Any]
global_state: Dict[Any, Any]

_condition: Condition[T_Event]
_condition: asyncio.Condition
_current_event: T_Event

_restart_flag: bool
_module_path_finder: ModulePathFinder
Expand Down Expand Up @@ -179,7 +179,7 @@ def restart(self):
async def _run(self):
"""运行 AliceBot。"""
self.should_exit = asyncio.Event()
self._condition = Condition()
self._condition = asyncio.Condition()

# 监听并拦截系统退出信号,从而完成一些善后工作后再关闭程序
if threading.current_thread() is threading.main_thread():
Expand Down Expand Up @@ -425,14 +425,16 @@ async def handle_event(
asyncio.create_task(self._handle_event())
await asyncio.sleep(0)
async with self._condition:
self._condition.notify_all(current_event)
self._current_event = current_event
self._condition.notify_all()
else:
asyncio.create_task(self._handle_event(current_event))

async def _handle_event(self, current_event: Optional[T_Event] = None):
if current_event is None:
async with self._condition:
current_event = await self._condition.wait()
await self._condition.wait()
current_event = self._current_event
if current_event.__handled__:
return

Expand Down Expand Up @@ -529,20 +531,20 @@ async def _wrapper(_event):

async with self._condition:
if timeout is None:
event = await self._condition.wait()
await self._condition.wait()
else:
try:
event = await asyncio.wait_for(
await asyncio.wait_for(
self._condition.wait(),
timeout=start_time + timeout - time.time(),
)
except asyncio.TimeoutError:
break

if not event.__handled__:
if await func(event):
event.__handled__ = True
return event
if not self._current_event.__handled__:
if await func(self._current_event):
self._current_event.__handled__ = True
return self._current_event

try_times += 1

Expand Down
85 changes: 1 addition & 84 deletions alicebot/utils.py
@@ -1,24 +1,21 @@
import os
import json
import asyncio
import inspect
import os.path
import pkgutil
import importlib
import collections
import dataclasses
from abc import ABC
from types import ModuleType
from importlib.abc import MetaPathFinder
from importlib.machinery import PathFinder
from typing import List, Type, Generic, TypeVar, Callable, Iterable, Optional
from typing import List, Type, Generic, TypeVar, Iterable, Optional

from pydantic import BaseModel

from alicebot.exceptions import LoadModuleError

__all__ = [
"Condition",
"ModulePathFinder",
"ModuleInfo",
"load_module",
Expand All @@ -32,86 +29,6 @@
_T = TypeVar("_T")


class Condition(Generic[_T]):
"""类似于 asyncio.Condition ,但允许在 notify() 时传递值,并由 wait() 返回。"""

def __init__(self):
self._loop = asyncio.get_running_loop()
lock = asyncio.Lock()
self._lock = lock
# Export the lock's locked(), acquire() and release() methods.
self.locked = lock.locked
self.acquire = lock.acquire
self.release = lock.release

self._waiters = collections.deque()

async def __aenter__(self):
await self.acquire()
# We have no use for the "as ..." clause in the with
# statement for locks.
return None

async def __aexit__(self, exc_type, exc, tb):
self.release()

def __repr__(self):
res = super().__repr__()
extra = "locked" if self.locked() else "unlocked"
if self._waiters:
extra = f"{extra}, waiters:{len(self._waiters)}"
return f"<{res[1:-1]} [{extra}]>"

async def wait(self) -> _T:
if not self.locked():
raise RuntimeError("cannot wait on un-acquired lock")

self.release()
try:
fut = self._loop.create_future()
self._waiters.append(fut)
try:
return await fut
finally:
self._waiters.remove(fut)

finally:
# Must reacquire lock even if wait is cancelled
cancelled = False
while True:
try:
await self.acquire()
break
except asyncio.CancelledError:
cancelled = True

if cancelled:
raise asyncio.CancelledError

async def wait_for(self, predicate: Callable[..., bool]) -> bool:
result = predicate()
while not result:
await self.wait()
result = predicate()
return result

def notify(self, value: _T = None, n: int = 1):
if not self.locked():
raise RuntimeError("cannot notify on un-acquired lock")

idx = 0
for fut in self._waiters:
if idx >= n:
break

if not fut.done():
idx += 1
fut.set_result(value)

def notify_all(self, value: _T = None):
self.notify(value, len(self._waiters))


class ModulePathFinder(MetaPathFinder):
"""用于查找 AliceBot 组件的元路径查找器。"""

Expand Down
Expand Up @@ -12,9 +12,9 @@

import aiohttp

from alicebot.utils import DataclassEncoder
from alicebot.adapter.utils import WebSocketAdapter
from alicebot.log import logger, error_or_exception
from alicebot.utils import Condition, DataclassEncoder

from .config import Config
from .event import get_event_class
Expand All @@ -29,7 +29,8 @@

class CQHTTPAdapter(WebSocketAdapter):
name = "cqhttp"
api_response_cond: Condition = None
_api_response: Dict[Any, Any]
_api_response_cond: asyncio.Condition = None
_api_id: int = 0

@property
Expand All @@ -49,7 +50,7 @@ async def startup(self):
self.port = self.config.port
self.url = self.config.url
self.reconnect_interval = self.config.reconnect_interval
self.api_response_cond = Condition()
self._api_response_cond = asyncio.Condition()
await super().startup()

async def reverse_ws_connection_hook(self):
Expand Down Expand Up @@ -89,8 +90,9 @@ async def handle_websocket_msg(self, msg: aiohttp.WSMessage):
if "post_type" in msg_dict:
await self.handle_cqhttp_event(msg_dict)
else:
async with self.api_response_cond:
self.api_response_cond.notify_all(msg_dict)
async with self._api_response_cond:
self._api_response = msg_dict
self._api_response_cond.notify_all()

elif msg.type == aiohttp.WSMsgType.ERROR:
logger.error(
Expand Down Expand Up @@ -166,20 +168,20 @@ async def call_api(self, api: str, **params) -> Dict[str, Any]:
while not self.bot.should_exit.is_set():
if time.time() - start_time > self.config.api_timeout:
break
async with self.api_response_cond:
async with self._api_response_cond:
try:
resp = await asyncio.wait_for(
self.api_response_cond.wait(),
await asyncio.wait_for(
self._api_response_cond.wait(),
timeout=start_time + self.config.api_timeout - time.time(),
)
except asyncio.TimeoutError:
break
if resp["echo"] == api_echo:
if resp.get("retcode") == 1404:
raise ApiNotAvailable(resp=resp)
if resp.get("status") == "failed":
raise ActionFailed(resp=resp)
return resp.get("data")
if self._api_response["echo"] == api_echo:
if self._api_response.get("retcode") == 1404:
raise ApiNotAvailable(resp=self._api_response)
if self._api_response.get("status") == "failed":
raise ActionFailed(resp=self._api_response)
return self._api_response.get("data")

if not self.bot.should_exit.is_set():
raise ApiTimeout
Expand Down
26 changes: 14 additions & 12 deletions packages/alicebot-adapter-mirai/alicebot/adapter/mirai/__init__.py
Expand Up @@ -13,9 +13,9 @@

import aiohttp

from alicebot.utils import DataclassEncoder
from alicebot.adapter.utils import WebSocketAdapter
from alicebot.log import logger, error_or_exception
from alicebot.utils import Condition, DataclassEncoder

from .config import Config
from .message import MiraiMessage
Expand All @@ -37,7 +37,8 @@ class MiraiAdapter(WebSocketAdapter):
"""

name: str = "mirai"
api_response_cond: Condition = None
_api_response: Any
_api_response_cond: asyncio.Condition = None
_sync_id: int = 0

@property
Expand All @@ -55,7 +56,7 @@ async def startup(self):
self.port = self.config.port
self.url = self.config.url
self.reconnect_interval = self.config.reconnect_interval
self.api_response_cond = Condition()
self._api_response_cond = asyncio.Condition()
await super().startup()

async def reverse_ws_connection_hook(self):
Expand Down Expand Up @@ -100,8 +101,9 @@ async def handle_websocket_msg(self, msg: aiohttp.WSMessage):
elif msg_dict.get("syncId") == "-1":
await self.handle_mirai_event(msg_dict.get("data"))
else:
async with self.api_response_cond:
self.api_response_cond.notify_all(msg_dict)
async with self._api_response_cond:
self._api_response = msg_dict
self._api_response_cond.notify_all()

elif msg.type == aiohttp.WSMsgType.ERROR:
logger.error(
Expand Down Expand Up @@ -190,19 +192,19 @@ async def call_api(
while not self.bot.should_exit.is_set():
if time.time() - start_time > self.config.api_timeout:
break
async with self.api_response_cond:
async with self._api_response_cond:
try:
resp = await asyncio.wait_for(
self.api_response_cond.wait(),
await asyncio.wait_for(
self._api_response_cond.wait(),
timeout=start_time + self.config.api_timeout - time.time(),
)
except asyncio.TimeoutError:
break
if resp.get("syncId") == sync_id:
status_code = resp.get("data", {}).get("code")
if self._api_response.get("syncId") == sync_id:
status_code = self._api_response.get("data", {}).get("code")
if status_code is not None and status_code != 0:
raise ActionFailed(code=status_code, resp=resp)
return resp.get("data")
raise ActionFailed(code=status_code, resp=self._api_response)
return self._api_response.get("data")

if not self.bot.should_exit.is_set():
raise ApiTimeout
Expand Down

0 comments on commit 3e9db43

Please sign in to comment.