Skip to content

Commit

Permalink
fix(bot): 修复加载插件时会自动重新加载插件所在模块的问题 (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
st1020 committed Aug 1, 2023
1 parent 9764cc2 commit a71a42d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
30 changes: 23 additions & 7 deletions alicebot/bot.py
Expand Up @@ -338,7 +338,9 @@ async def _run_hot_reload(self) -> None:

if change_type == Change.added:
logger.info(f"Hot reload: Added file: {file}")
self._load_plugins(Path(file), plugin_load_type=PluginLoadType.DIR)
self._load_plugins(
Path(file), plugin_load_type=PluginLoadType.DIR, reload=True
)
self._update_config()
continue
if change_type == Change.deleted:
Expand All @@ -348,7 +350,9 @@ async def _run_hot_reload(self) -> None:
elif change_type == Change.modified:
logger.info(f"Hot reload: Modified file: {file}")
self._remove_plugin_by_path(file)
self._load_plugins(Path(file), plugin_load_type=PluginLoadType.DIR)
self._load_plugins(
Path(file), plugin_load_type=PluginLoadType.DIR, reload=True
)
self._update_config()

def _update_config(self) -> None:
Expand Down Expand Up @@ -665,11 +669,17 @@ def _load_plugin_class(
)

def _load_plugins_from_module_name(
self, module_name: str, plugin_load_type: PluginLoadType
self,
module_name: str,
*,
plugin_load_type: PluginLoadType,
reload: bool = False,
) -> None:
"""从模块名称中插件模块。"""
try:
plugin_classes = get_classes_from_module_name(module_name, Plugin)
plugin_classes = get_classes_from_module_name(
module_name, Plugin, reload=reload
)
except ImportError as e:
error_or_exception(
f'Import module "{module_name}" failed:',
Expand All @@ -688,6 +698,7 @@ def _load_plugins(
self,
*plugins: Union[Type[Plugin[Any, Any, Any]], str, Path],
plugin_load_type: Optional[PluginLoadType] = None,
reload: bool = False,
) -> None:
"""加载插件。
Expand All @@ -699,6 +710,7 @@ def _load_plugins(
如果为 `pathlib.Path` 类型时,将作为插件模块文件路径进行加载。
例如:`pathlib.Path("path/of/plugin")`。
plugin_load_type: 插件加载类型,如果为 `None` 则自动判断,否则使用指定的类型。
reload: 是否重新加载模块。
"""
for plugin_ in plugins:
if isinstance(plugin_, type):
Expand All @@ -713,7 +725,9 @@ def _load_plugins(
elif isinstance(plugin_, str):
logger.info(f'Loading plugins from module "{plugin_}"')
self._load_plugins_from_module_name(
plugin_, plugin_load_type or PluginLoadType.NAME
plugin_,
plugin_load_type=plugin_load_type or PluginLoadType.NAME,
reload=reload,
)
elif isinstance(plugin_, Path):
logger.info(f'Loading plugins from path "{plugin_}"')
Expand Down Expand Up @@ -744,7 +758,9 @@ def _load_plugins(
)

self._load_plugins_from_module_name(
plugin_module_name, plugin_load_type or PluginLoadType.FILE
plugin_module_name,
plugin_load_type=plugin_load_type or PluginLoadType.FILE,
reload=reload,
)
else:
logger.error(f'The plugin path "{plugin_}" must be a file')
Expand Down Expand Up @@ -781,7 +797,7 @@ def _load_plugins_from_dirs(self, *dirs: Path) -> None:
for module_info in pkgutil.iter_modules(dir_list):
if not module_info.name.startswith("_"):
self._load_plugins_from_module_name(
module_info.name, PluginLoadType.DIR
module_info.name, plugin_load_type=PluginLoadType.DIR
)

def load_plugins_from_dirs(self, *dirs: Path) -> None:
Expand Down
6 changes: 4 additions & 2 deletions alicebot/utils.py
Expand Up @@ -118,13 +118,14 @@ def get_classes_from_module(module: ModuleType, super_class: _TypeT) -> List[_Ty


def get_classes_from_module_name(
name: str, super_class: _TypeT
name: str, super_class: _TypeT, *, reload: bool = False
) -> List[Tuple[_TypeT, ModuleType]]:
"""从指定名称的模块中查找指定类型的类。
Args:
name: 模块名称,格式和 Python `import` 语句相同。
super_class: 要查找的类的超类。
reload: 是否重新加载模块。
Returns:
返回由符合条件的类和模块组成的元组的列表。
Expand All @@ -135,7 +136,8 @@ def get_classes_from_module_name(
try:
importlib.invalidate_caches()
module = importlib.import_module(name)
importlib.reload(module)
if reload:
importlib.reload(module)
return [(x, module) for x in get_classes_from_module(module, super_class)]
except KeyboardInterrupt:
# 不捕获 KeyboardInterrupt
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -144,6 +144,7 @@ disable = [
"arguments-differ",
"broad-exception-caught",
"import-outside-toplevel",
"too-many-lines",
"duplicate-code",
"too-few-public-methods",
"too-many-arguments",
Expand Down

0 comments on commit a71a42d

Please sign in to comment.