Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions astrbot/core/config/astrbot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import logging
import enum
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from typing import Dict

ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
logger = logging.getLogger("astrbot")
Expand Down Expand Up @@ -43,7 +42,7 @@ def __init__(
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)

with open(config_path, "r", encoding="utf-8-sig") as f:
with open(config_path, encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith("/ufeff"): # remove BOM
conf_str = conf_str.encode("utf8")[3:].decode("utf8")
Expand Down Expand Up @@ -82,7 +81,7 @@ def _parse_schema(schema: dict, conf: dict):

return conf

def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
def check_config_integrity(self, refer_conf: dict, conf: dict, path=""):
"""检查配置完整性,如果有新的配置项则返回 True"""
has_new = False
for key, value in refer_conf.items():
Expand All @@ -102,7 +101,7 @@ def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
)
return has_new

def save_config(self, replace_config: Dict = None):
def save_config(self, replace_config: dict = None):
"""将配置写入文件

如果传入 replace_config,则将配置替换为 replace_config
Expand Down
7 changes: 3 additions & 4 deletions astrbot/core/conversation_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import json
import asyncio
from astrbot.core import sp
from typing import Dict, List
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation

Expand All @@ -19,7 +18,7 @@ class ConversationManager:

def __init__(self, db_helper: BaseDatabase):
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.session_conversations: dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
Expand Down Expand Up @@ -100,7 +99,7 @@ async def get_conversation(
"""
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)

async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
async def get_conversations(self, unified_msg_origin: str) -> list[Conversation]:
"""获取会话的所有对话

Args:
Expand All @@ -111,7 +110,7 @@ async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]
return self.db.get_conversations(unified_msg_origin)

async def update_conversation(
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
self, unified_msg_origin: str, conversation_id: str, history: list[dict]
):
"""更新会话的对话

Expand Down
5 changes: 2 additions & 3 deletions astrbot/core/core_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
Expand Down Expand Up @@ -115,7 +114,7 @@ async def initialize(self):
self.start_time = int(time.time())

# 初始化当前任务列表
self.curr_tasks: List[asyncio.Task] = []
self.curr_tasks: list[asyncio.Task] = []

# 根据配置实例化各个平台适配器
await self.platform_manager.initialize()
Expand Down Expand Up @@ -220,7 +219,7 @@ async def restart(self):
target=self.astrbot_updator._reboot, name="restart", daemon=True
).start()

def load_platform(self) -> List[asyncio.Task]:
def load_platform(self) -> list[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = []
platform_insts = self.platform_manager.get_insts()
Expand Down
18 changes: 9 additions & 9 deletions astrbot/core/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def update_llm_history(self, session_id: str, content: str, provider_type: str):
@abc.abstractmethod
def get_llm_history(
self, session_id: str = None, provider_type: str = None
) -> List[LLMHistory]:
) -> list[LLMHistory]:
"""获取 LLM 历史记录, 如果 session_id 为 None, 返回所有"""
raise NotImplementedError

Expand All @@ -73,7 +73,7 @@ def insert_atri_vision_data(self, vision_data: ATRIVision):
raise NotImplementedError

@abc.abstractmethod
def get_atri_vision_data(self) -> List[ATRIVision]:
def get_atri_vision_data(self) -> list[ATRIVision]:
"""获取 ATRI 视觉数据"""
raise NotImplementedError

Expand All @@ -95,7 +95,7 @@ def new_conversation(self, user_id: str, cid: str):
raise NotImplementedError

@abc.abstractmethod
def get_conversations(self, user_id: str) -> List[Conversation]:
def get_conversations(self, user_id: str) -> list[Conversation]:
raise NotImplementedError

@abc.abstractmethod
Expand All @@ -121,7 +121,7 @@ def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str
@abc.abstractmethod
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
) -> tuple[list[dict[str, Any]], int]:
"""获取所有对话,支持分页

Args:
Expand All @@ -138,12 +138,12 @@ def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
platforms: list[str] = None,
message_types: list[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
exclude_ids: list[str] = None,
exclude_platforms: list[str] = None,
) -> tuple[list[dict[str, Any]], int]:
"""获取筛选后的对话列表

Args:
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/db/plugin/sqlite_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __new__(cls):
"""
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
if cls._instance is None:
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
cls._instance = super().__new__(cls)
cls._instance.db_path = DBPATH
return cls._instance

Expand Down
9 changes: 4 additions & 5 deletions astrbot/core/db/po.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""指标数据"""

from dataclasses import dataclass, field
from typing import List


@dataclass
Expand Down Expand Up @@ -42,9 +41,9 @@ class Command:

@dataclass
class Stats:
platform: List[Platform] = field(default_factory=list)
command: List[Command] = field(default_factory=list)
llm: List[Provider] = field(default_factory=list)
platform: list[Platform] = field(default_factory=list)
command: list[Command] = field(default_factory=list)
llm: list[Provider] = field(default_factory=list)


@dataclass
Expand All @@ -64,7 +63,7 @@ class ATRIVision:
url_or_path: str
caption: str
is_meme: bool
keywords: List[str]
keywords: list[str]
platform_name: str
session_id: str
sender_nickname: str
Expand Down
24 changes: 12 additions & 12 deletions astrbot/core/db/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import time
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
from . import BaseDatabase
from typing import Tuple, List, Dict, Any
from typing import Any


class SQLiteDatabase(BaseDatabase):
def __init__(self, db_path: str) -> None:
super().__init__()
self.db_path = db_path

with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
with open(os.path.dirname(__file__) + "/sqlite_init.sql") as f:
sql = f.read()

# 初始化数据库
Expand Down Expand Up @@ -56,7 +56,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection:
conn.text_factory = str
return conn

def _exec_sql(self, sql: str, params: Tuple = None):
def _exec_sql(self, sql: str, params: tuple = None):
conn = self.conn
try:
c = self.conn.cursor()
Expand Down Expand Up @@ -122,7 +122,7 @@ def update_llm_history(self, session_id: str, content: str, provider_type: str):

def get_llm_history(
self, session_id: str = None, provider_type: str = None
) -> Tuple:
) -> tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
Expand Down Expand Up @@ -268,7 +268,7 @@ def new_conversation(self, user_id: str, cid: str):
(user_id, cid, history, updated_at, created_at),
)

def get_conversations(self, user_id: str) -> Tuple:
def get_conversations(self, user_id: str) -> tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
Expand Down Expand Up @@ -349,7 +349,7 @@ def insert_atri_vision_data(self, vision: ATRIVision):
),
)

def get_atri_vision_data(self) -> Tuple:
def get_atri_vision_data(self) -> tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
Expand Down Expand Up @@ -391,7 +391,7 @@ def get_atri_vision_data_by_path_or_id(

def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
) -> tuple[list[dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
Expand Down Expand Up @@ -452,12 +452,12 @@ def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
platforms: list[str] = None,
message_types: list[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
exclude_ids: list[str] = None,
exclude_platforms: list[str] = None,
) -> tuple[list[dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
Expand Down
3 changes: 1 addition & 2 deletions astrbot/core/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import sys
from collections import deque
from asyncio import Queue
from typing import List

# 日志缓存大小
CACHED_SIZE = 200
Expand Down Expand Up @@ -87,7 +86,7 @@ class LogBroker:

def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: List[Queue] = [] # 订阅者列表
self.subscribers: list[Queue] = [] # 订阅者列表

def register(self) -> Queue:
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/message/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def toString(self):
k = "type"
if isinstance(v, bool):
v = 1 if v else 0
output += ",%s=%s" % (
output += ",{}={}".format(
k,
str(v)
.replace("&", "&")
Expand Down
15 changes: 8 additions & 7 deletions astrbot/core/message/message_event_result.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum

from typing import List, Optional, Union, AsyncGenerator

from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from astrbot.core.message.components import (
BaseMessageComponent,
Expand All @@ -22,8 +23,8 @@ class MessageChain:
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
"""

chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
chain: list[BaseMessageComponent] = field(default_factory=list)
use_t2i_: bool | None = None # None 为跟随用户设置

def message(self, message: str):
"""添加一条文本消息到消息链 `chain` 中。
Expand All @@ -37,7 +38,7 @@ def message(self, message: str):
self.chain.append(Plain(message))
return self

def at(self, name: str, qq: Union[str, int]):
def at(self, name: str, qq: str | int):
"""添加一条 At 消息到消息链 `chain` 中。

Example:
Expand Down Expand Up @@ -172,15 +173,15 @@ class MessageEventResult(MessageChain):
`result_type` (EventResultType): 事件处理的结果类型。
"""

result_type: Optional[EventResultType] = field(
result_type: EventResultType | None = field(
default_factory=lambda: EventResultType.CONTINUE
)

result_content_type: Optional[ResultContentType] = field(
result_content_type: ResultContentType | None = field(
default_factory=lambda: ResultContentType.GENERAL_RESULT
)

async_stream: Optional[AsyncGenerator] = None
async_stream: AsyncGenerator | None = None
"""异步流"""

def stop_event(self) -> "MessageEventResult":
Expand Down
5 changes: 3 additions & 2 deletions astrbot/core/pipeline/content_safety_check/stage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union, AsyncGenerator

from collections.abc import AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
Expand All @@ -20,7 +21,7 @@ async def initialize(self, ctx: PipelineContext):

async def process(
self, event: AstrMessageEvent, check_text: str = None
) -> Union[None, AsyncGenerator[None, None]]:
) -> None | AsyncGenerator[None, None]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@

class ContentSafetyStrategy(abc.ABC):
@abc.abstractmethod
def check(self, content: str) -> Tuple[bool, str]:
def check(self, content: str) -> tuple[bool, str]:
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from . import ContentSafetyStrategy
from typing import List, Tuple
from astrbot import logger


class StrategySelector:
def __init__(self, config: dict) -> None:
self.enabled_strategies: List[ContentSafetyStrategy] = []
self.enabled_strategies: list[ContentSafetyStrategy] = []
if config["internal_keywords"]["enable"]:
from .keywords import KeywordsStrategy

Expand All @@ -26,7 +25,7 @@ def __init__(self, config: dict) -> None:
)
)

def check(self, content: str) -> Tuple[bool, str]:
def check(self, content: str) -> tuple[bool, str]:
for strategy in self.enabled_strategies:
ok, info = strategy.check(content)
if not ok:
Expand Down
Loading