/
__init__.py
217 lines (186 loc) · 7.59 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""CQHTTP 协议适配器。
本适配器适配了 OneBot v11 协议。
协议详情请参考: [OneBot](https://github.com/howmanybots/onebot/blob/master/README.md) 。
"""
import sys
import hmac
import json
import time
import asyncio
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, Literal

import aiohttp

from alicebot.utils import DataclassEncoder
from alicebot.adapter.utils import WebSocketAdapter
from alicebot.log import logger, error_or_exception

from .config import Config
from .event import get_event_class
from .message import CQHTTPMessage
from .exceptions import ApiTimeout, ActionFailed, NetworkError, ApiNotAvailable

if TYPE_CHECKING:
    from .message import T_CQMSG
__all__ = ["CQHTTPAdapter"]


class CQHTTPAdapter(WebSocketAdapter):
    """OneBot v11 (CQHTTP) adapter communicating over WebSocket.

    Any attribute that is not otherwise defined is turned into an API call,
    so ``await adapter.get_login_info()`` is equivalent to
    ``await adapter.call_api("get_login_info")``.
    """

    name = "cqhttp"

    # Most recent API-response message received. Kept for backward
    # compatibility / inspection only; call_api() does not read it.
    _api_response: Dict[Any, Any]
    # Pending API calls keyed by the integer ``echo`` sent with each request.
    # Created in startup().
    _api_futures: Dict[int, asyncio.Future]
    _api_id: int = 0

    @property
    def config(self):
        """This adapter's section of the bot configuration."""
        return getattr(self.bot.config, Config.__config_name__)

    def __getattr__(self, item: str):
        """Return ``partial(self.call_api, item)`` for unknown attributes.

        Dunder names are excluded so that protocol lookups (copy, pickle,
        inspect, ...) get the usual AttributeError instead of an API
        partial that would silently be treated as an implementation.
        """
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return partial(self.call_api, item)

    async def startup(self):
        """Load connection settings from the config and start the adapter."""
        self.adapter_type = self.config.adapter_type
        if self.adapter_type == "ws-reverse":
            # Normalize the alternative spelling to the one the base uses.
            self.adapter_type = "reverse-ws"
        self.host = self.config.host
        self.port = self.config.port
        self.url = self.config.url
        self.reconnect_interval = self.config.reconnect_interval
        self._api_futures = {}
        await super().startup()

    async def reverse_ws_connection_hook(self):
        """Hook run when a reverse WebSocket connection is established.

        When ``access_token`` is configured, the connection is closed unless
        its ``Authorization`` header equals ``Bearer <access_token>``.
        """
        logger.info("WebSocket connected!")
        if not self.config.access_token:
            return
        expected = f"Bearer {self.config.access_token}"
        received = self.websocket.headers.get("Authorization", "")
        # Constant-time comparison so the token cannot be probed by timing.
        # NOTE(review): if self.websocket is an aiohttp WebSocketResponse,
        # `.headers` are the *response* headers rather than the client's
        # request headers - confirm against WebSocketAdapter that this check
        # actually sees the header sent by the OneBot implementation.
        if not hmac.compare_digest(received.encode(), expected.encode()):
            logger.warning("WebSocket connection rejected: invalid access token")
            await self.websocket.close()

    async def websocket_connect(self):
        """Open a forward WebSocket connection and process it until closed."""
        logger.info("Trying to connect to WebSocket server...")
        async with self.session.ws_connect(
            f"ws://{self.host}:{self.port}/",
            headers={"Authorization": f"Bearer {self.config.access_token}"}
            if self.config.access_token
            else None,
        ) as self.websocket:
            await self.handle_websocket()

    async def handle_websocket_msg(self, msg: aiohttp.WSMessage):
        """Dispatch one incoming WebSocket message.

        Text frames containing ``post_type`` are events and go to
        handle_cqhttp_event(). Any other JSON object is treated as an API
        response and delivered to the call_api() waiting on its ``echo``.
        Error frames are logged.
        """
        if msg.type == aiohttp.WSMsgType.TEXT:
            try:
                msg_dict = msg.json()
            except json.JSONDecodeError as e:
                error_or_exception(
                    "WebSocket message parsing error, not json:",
                    e,
                    self.bot.config.verbose_exception_log,
                )
                return
            if not isinstance(msg_dict, dict):
                # Neither an event nor an API response; previously this
                # reached call_api() and crashed every pending call.
                logger.warning(f"WebSocket message is not a JSON object: {msg_dict!r}")
                return
            if "post_type" in msg_dict:
                await self.handle_cqhttp_event(msg_dict)
            else:
                self._resolve_api_response(msg_dict)
        elif msg.type == aiohttp.WSMsgType.ERROR:
            logger.error(
                f"WebSocket connection closed "
                f"with exception {self.websocket.exception()!r}"
            )

    def _resolve_api_response(self, resp: Dict[str, Any]) -> None:
        """Hand an API response to the pending call_api() with the same echo."""
        self._api_response = resp
        echo = resp.get("echo")
        # Only int echoes are ever sent (see _get_api_echo); anything else -
        # missing, str, list, ... - cannot belong to a pending call.
        future = self._api_futures.get(echo) if isinstance(echo, int) else None
        if future is not None and not future.done():
            future.set_result(resp)

    def _get_api_echo(self) -> int:
        """Return the next request ``echo`` id, wrapping at ``sys.maxsize``."""
        self._api_id = (self._api_id + 1) % sys.maxsize
        return self._api_id

    async def handle_cqhttp_event(self, msg: Dict[str, Any]):
        """Build an event object from ``msg`` and dispatch it.

        Meta events (lifecycle / heartbeat) are handled here and are not
        passed to plugins; every other event goes to handle_event().

        Args:
            msg: The decoded message, which contains a ``post_type`` key.
        """
        post_type = msg.get("post_type")
        event_type = msg.get(post_type + "_type")
        sub_type = msg.get("sub_type")
        event_class = get_event_class(post_type, event_type, sub_type)
        cqhttp_event = event_class(adapter=self, **msg)

        if cqhttp_event.post_type != "meta_event":
            await self.handle_event(cqhttp_event)
            return

        # meta_event is not handed to plugins.
        if (
            cqhttp_event.meta_event_type == "lifecycle"
            and cqhttp_event.sub_type == "connect"
        ):
            logger.info(
                f"WebSocket connection "
                f"from CQHTTP Bot {msg.get('self_id')} accepted!"
            )
        elif cqhttp_event.meta_event_type == "heartbeat":
            status = cqhttp_event.status
            if not (status.good and status.online):
                logger.error(
                    f"CQHTTP Bot status is not good: {cqhttp_event.status.dict()}"
                )

    async def call_api(self, api: str, **params) -> Dict[str, Any]:
        """Call a CQHTTP API and wait for its response.

        Each request carries a unique integer ``echo``. The response with the
        same ``echo`` is delivered through a per-request future, so
        concurrent calls cannot overwrite each other's responses.

        Args:
            api: API name.
            **params: API parameters.

        Returns:
            The ``data`` field of the API response. Returns None if the bot
            is shutting down (``bot.should_exit`` is set) before a response
            is received.

        Raises:
            NetworkError: Sending the request failed. The original error is
                chained as ``__cause__``.
            ApiNotAvailable: The response has ``retcode`` 1404 (API missing).
            ActionFailed: The response has ``status`` "failed".
            ApiTimeout: No response within ``config.api_timeout`` seconds.
        """
        api_echo = self._get_api_echo()
        # Register the waiter *before* sending, so a response that arrives
        # while send_str() is still awaiting cannot be missed.
        future = asyncio.get_running_loop().create_future()
        self._api_futures[api_echo] = future
        try:
            try:
                await self.websocket.send_str(
                    json.dumps(
                        {"action": api, "params": params, "echo": api_echo},
                        cls=DataclassEncoder,
                    )
                )
            except Exception as e:
                raise NetworkError from e

            if self.bot.should_exit.is_set():
                # Shutting down: don't wait for (or time out on) a response.
                return None
            try:
                resp = await asyncio.wait_for(
                    future, timeout=self.config.api_timeout
                )
            except asyncio.TimeoutError:
                if self.bot.should_exit.is_set():
                    return None
                raise ApiTimeout from None
        finally:
            # Always drop the waiter so the pending table cannot grow.
            self._api_futures.pop(api_echo, None)

        if resp.get("retcode") == 1404:
            raise ApiNotAvailable(resp=resp)
        if resp.get("status") == "failed":
            raise ActionFailed(resp=resp)
        return resp.get("data")

    async def send(
        self, message_: "T_CQMSG", message_type: Literal["private", "group"], id_: int
    ) -> Dict[str, Any]:
        """Send a message via the send_private_msg or send_group_msg API.

        Args:
            message_: Message content: str, Mapping, Iterable[Mapping],
                'CQHTTPMessageSegment' or 'CQHTTPMessage'. It is wrapped in
                `CQHTTPMessage`.
            message_type: Message type, either "private" or "group".
            id_: Recipient ID, either a QQ user id or a group id.

        Returns:
            The API response data (see `call_api()`).

        Raises:
            TypeError: message_type is neither "private" nor "group".
            ...: Anything raised by `call_api()`.
        """
        if message_type == "private":
            return await self.send_private_msg(
                user_id=id_, message=CQHTTPMessage(message_)
            )
        if message_type == "group":
            return await self.send_group_msg(
                group_id=id_, message=CQHTTPMessage(message_)
            )
        raise TypeError('message_type must be "private" or "group"')