diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 2f19434c9d..f35102bd9b 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -28,6 +28,7 @@ import sys import uuid from enum import Enum +from pathlib import Path, PurePosixPath if sys.version_info >= (3, 14): from pydantic import BaseModel @@ -662,6 +663,19 @@ class Unknown(BaseMessageComponent): text: str +def _sanitize_file_component_name(name: str | None) -> str: + if not name: + return "file" + + normalized = str(name).replace("\\", "/") + basename = PurePosixPath(normalized).name.replace("\x00", "").strip() + for char in ':*?"<>|': + basename = basename.replace(char, "_") + if basename in {"", ".", ".."}: + return "file" + return basename + + class File(BaseMessageComponent): """文件消息段""" @@ -773,15 +787,18 @@ async def _download_file(self) -> None: """下载文件""" if not self.url: raise ValueError("Download failed: No URL provided in File component.") - download_dir = get_astrbot_temp_path() + download_dir = Path(get_astrbot_temp_path()) + download_dir.mkdir(parents=True, exist_ok=True) if self.name: - name, ext = os.path.splitext(self.name) + safe_name = _sanitize_file_component_name(self.name) + name = Path(safe_name).stem + ext = Path(safe_name).suffix filename = f"fileseg_{name}_{uuid.uuid4().hex[:8]}{ext}" else: filename = f"fileseg_{uuid.uuid4().hex}" - file_path = os.path.join(download_dir, filename) - await download_file(self.url, file_path) - self.file_ = os.path.abspath(file_path) + file_path = download_dir / filename + await download_file(self.url, str(file_path)) + self.file_ = str(file_path.resolve()) async def register_to_file_service(self) -> str: """将文件注册到文件服务。 diff --git a/tests/unit/test_file_message_component.py b/tests/unit/test_file_message_component.py new file mode 100644 index 0000000000..f7ecd121ed --- /dev/null +++ b/tests/unit/test_file_message_component.py @@ -0,0 +1,39 @@ +from pathlib import Path + +import pytest + +from astrbot.core.message import components + + +@pytest.mark.asyncio +async def test_file_component_download_sanitizes_remote_name(monkeypatch, tmp_path): + temp_dir = tmp_path / "temp" + downloaded_paths: list[Path] = [] + + async def fake_download_file(url: str, path: str) -> None: + target = Path(path) + assert url == "https://example.com/report" + assert target.parent == temp_dir + assert target.parent.exists() + assert "\x00" not in target.name + assert "/" not in target.name + assert "\\" not in target.name + assert not any(char in target.name for char in ':*?"<>|') + target.write_bytes(b"payload") + downloaded_paths.append(target) + + monkeypatch.setattr(components, "download_file", fake_download_file) + monkeypatch.setattr(components, "get_astrbot_temp_path", lambda: str(temp_dir)) + + component = components.File( + name='..\\nested/evil\\report:*?"<>|\x00.pdf', + url="https://example.com/report", + ) + + path = Path(await component.get_file()) + + assert path.parent == temp_dir + assert path.exists() + assert path.name.startswith("fileseg_report________") + assert path.suffix == ".pdf" + assert downloaded_paths == [path]