Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion astrbot/core/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .whitelist_check.stage import WhitelistCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .plugin_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core import logger
from astrbot.core.platform.message_type import MessageType


@register_stage
Expand All @@ -28,6 +29,11 @@ async def process(
# 获取当前平台ID
platform_id = event.get_platform_id()

# 添加群聊ID日志
group_id = None
if event.get_message_type() == MessageType.GROUP_MESSAGE:
group_id = f"{event.get_platform_name()}:{event.get_group_id()}"

# 获取已激活的处理器
activated_handlers = event.get_extra("activated_handlers")
if activated_handlers is None:
Expand All @@ -37,13 +43,14 @@ async def process(
for handler in activated_handlers:
if not isinstance(handler, StarHandlerMetadata):
continue
# 检查处理器是否在当前平台启用
enabled = handler.is_enabled_for_platform(platform_id)

if handler.handler_module_path in star_map:
plugin_name = star_map[handler.handler_module_path].name

enabled = handler.is_enabled_for_platform(platform_id, group_id)
if not enabled:
if handler.handler_module_path in star_map:
plugin_name = star_map[handler.handler_module_path].name
logger.debug(
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
f"[权限调试] 插件 {plugin_name} 在平台 {platform_id} {'和群聊 ' + group_id if group_id else ''} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
)
# 设置处理器为平台不兼容状态
# TODO: 更好的标记方式
Expand Down
10 changes: 8 additions & 2 deletions astrbot/core/pipeline/process_stage/method/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,15 @@ async def process(
# 执行请求 LLM 前事件钩子。
# 装饰 system_prompt 等功能
# 获取当前平台ID
platform_id = event.get_platform_id()
# 构建群聊标识符
group_id = None
if event.get_group_id():
group_id = f"{event.get_platform_name()}:{event.get_group_id()}"

handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMRequestEvent, platform_id=platform_id
EventType.OnAfterMessageSentEvent,
platform_id=event.get_platform_id(),
group_id=group_id,
)
for handler in handlers:
try:
Expand Down
9 changes: 8 additions & 1 deletion astrbot/core/pipeline/respond/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,15 @@ async def process(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
)

# 构建群聊标识符
group_id = None
if event.get_group_id():
group_id = f"{event.get_platform_name()}:{event.get_group_id()}"

handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
EventType.OnAfterMessageSentEvent,
platform_id=event.get_platform_id(),
group_id=group_id,
)
for handler in handlers:
try:
Expand Down
9 changes: 8 additions & 1 deletion astrbot/core/pipeline/waking_check/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,15 @@ async def process(
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler

# 构建群聊标识符
group_id = None
if event.get_group_id():
group_id = f"{event.get_platform_name()}:{event.get_group_id()}"

for handler in star_handlers_registry.get_handlers_by_event_type(
EventType.AdapterMessageEvent
EventType.AdapterMessageEvent,
platform_id=event.get_platform_id(),
group_id=group_id,
):
# filter 需满足 AND 逻辑关系
passed = True
Expand Down
59 changes: 51 additions & 8 deletions astrbot/core/star/star.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Dict
from dataclasses import dataclass, field
from astrbot.core.config import AstrBotConfig
from astrbot.core import logger

star_registry: List[StarMetadata] = []
star_map: Dict[str, StarMetadata] = {}
Expand Down Expand Up @@ -50,26 +51,68 @@ class StarMetadata:
supported_platforms: Dict[str, bool] = field(default_factory=dict)
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""

group_permissions: Dict[str, bool] = field(default_factory=dict)
"""插件在群聊中的权限缓存,key为群聊ID,value为插件是否启用"""

def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"

def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
def update_plugin_compatibility(self, plugin_enable_config: dict) -> None:
"""更新插件支持的平台列表

Args:
plugin_enable_config: 平台插件启用配置,即platform_settings.plugin_enable配置项
"""
self.update_platform_config(plugin_enable_config)
self.update_group_permissions()

logger.debug(
f"[权限调试] 插件 {self.name} 最终兼容性配置: {self.supported_platforms}"
)
logger.debug(
f"[权限调试] 插件 {self.name} 群聊黑名单: {self.group_permissions}"
)

def update_platform_config(self, plugin_enable_config: dict) -> None:
if not plugin_enable_config:
return

# 清空之前的配置
self.supported_platforms.clear()

# 遍历所有平台配置
for platform_id, plugins in plugin_enable_config.items():
# 检查该插件在当前平台的配置
# 处理平台配置
if plugin_enable_config:
# 遍历所有平台配置
for platform_id, plugins in plugin_enable_config.items():
# 检查该插件在当前平台的配置
if self.name in plugins:
self.supported_platforms[platform_id] = plugins[self.name]
logger.debug(
f"[权限调试] 设置平台配置: {platform_id} = {plugins[self.name]}"
)
else:
# 如果没有明确配置,默认为启用
self.supported_platforms[platform_id] = True
logger.debug(f"[权限调试] 默认启用平台: {platform_id}")

def update_group_permissions(self) -> None:
from astrbot.core import astrbot_config

# 清空并更新群聊权限缓存
self.group_permissions.clear()

group_settings = astrbot_config.get("group_settings", {})
plugin_enable = group_settings.get("plugin_enable", {})

# 遍历所有群聊配置,只缓存被禁用的(黑名单方式)
for group_id, plugins in plugin_enable.items():
logger.debug(f"[权限调试] 处理群聊 {group_id} 的插件配置: {plugins}")
if self.name in plugins:
self.supported_platforms[platform_id] = plugins[self.name]
else:
# 如果没有明确配置,默认为启用
self.supported_platforms[platform_id] = True
is_enabled = plugins[self.name]
# 只缓存被禁用的群聊(黑名单方式)
if not is_enabled:
self.group_permissions[group_id] = False
logger.debug(
f"[权限调试] 缓存禁用的群聊 {group_id} 配置: {is_enabled}"
)
logger.debug(f"[权限调试] 设置群聊配置: {group_id} = {is_enabled}")
34 changes: 30 additions & 4 deletions astrbot/core/star/star_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,23 @@ def _print_handlers(self):
print(handler.handler_full_name)

def get_handlers_by_event_type(
self, event_type: EventType, only_activated=True, platform_id=None
self,
event_type: EventType,
only_activated=True,
platform_id=None,
group_id=None,
) -> List[StarHandlerMetadata]:
"""通过事件类型获取 Handler

Args:
event_type: 事件类型
only_activated: 是否只返回已激活的插件的处理器
platform_id: 平台ID,如果提供此参数,将过滤掉在此平台不兼容的处理器
group_id: "平台名称:群组ID" event.get_platform_name():event.get_group_id()

Returns:
List[StarHandlerMetadata]: 处理器列表
"""
handlers = []
for handler in self._handlers:
if handler.event_type != event_type:
Expand All @@ -37,7 +52,7 @@ def get_handlers_by_event_type(
if not (plugin and plugin.activated):
continue
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
if not handler.is_enabled_for_platform(platform_id):
if not handler.is_enabled_for_platform(platform_id, group_id):
continue
handlers.append(handler)
return handlers
Expand Down Expand Up @@ -120,15 +135,17 @@ def __lt__(self, other: StarHandlerMetadata):
"priority", 0
)

def is_enabled_for_platform(self, platform_id: str) -> bool:
def is_enabled_for_platform(self, platform_id: str, group_id=None) -> bool:
"""检查插件是否在指定平台启用

Args:
platform_id: 平台ID,这是从event.get_platform_id()获取的,用于唯一标识平台实例
group_id: "平台名称:群组ID" event.get_platform_name():event.get_group_id()

Returns:
bool: 是否启用,True表示启用,False表示禁用
"""

plugin = star_map.get(self.handler_module_path)

# 如果插件元数据不存在,默认允许执行
Expand All @@ -139,12 +156,21 @@ def is_enabled_for_platform(self, platform_id: str) -> bool:
if not plugin.activated:
return False

# 检查群聊插件权限设置 - 使用缓存的群聊权限
if group_id and hasattr(plugin, "group_permissions"):
# 如果群组ID在黑名单中(被明确禁用),则返回False
if group_id in plugin.group_permissions:
is_enabled = plugin.group_permissions[group_id]
# 由于我们只缓存禁用的群聊,所以这里应该总是False
return is_enabled

# 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性
if (
hasattr(plugin, "supported_platforms")
and platform_id in plugin.supported_platforms
):
return plugin.supported_platforms[platform_id]
is_enabled = plugin.supported_platforms[platform_id]
return is_enabled

# 如果没有缓存数据,默认允许执行
return True
10 changes: 5 additions & 5 deletions astrbot/core/star/star_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,11 @@ async def reload(self, specified_plugin_name=None):
result = await self.load(specified_module_path)

# 更新所有插件的平台兼容性
await self.update_all_platform_compatibility()
await self.update_all_plugin_compatibility()

return result

async def update_all_platform_compatibility(self):
async def update_all_plugin_compatibility(self):
"""更新所有插件的平台兼容性设置"""
# 获取最新的平台插件启用配置
plugin_enable_config = self.config.get("platform_settings", {}).get(
Expand All @@ -350,9 +350,9 @@ async def update_all_platform_compatibility(self):

# 遍历所有插件,更新平台兼容性
for plugin in self.context.get_all_stars():
plugin.update_platform_compatibility(plugin_enable_config)
plugin.update_plugin_compatibility(plugin_enable_config)
logger.debug(
f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
f"[权限调试] 插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
)

return True
Expand Down Expand Up @@ -479,7 +479,7 @@ async def load(self, specified_module_path=None, specified_dir_name=None):
plugin_enable_config = self.config.get("platform_settings", {}).get(
"plugin_enable", {}
)
metadata.update_platform_compatibility(plugin_enable_config)
metadata.update_plugin_compatibility(plugin_enable_config)

# 绑定 handler
related_handlers = (
Expand Down
Loading