Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions astrbot/core/computer/booters/cua.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,43 @@ def _has_component_method(root: Any, component_name: str, method_name: str) -> b
return getattr(component, method_name, None) is not None


def _resolve_files_components(sandbox: Any) -> tuple[Any, ...]:
components: list[Any] = []
seen_ids: set[int] = set()
for name in ("files", "filesystem"):
component = getattr(sandbox, name, None)
if component is None:
continue
component_id = id(component)
if component_id in seen_ids:
continue
seen_ids.add(component_id)
components.append(component)
return tuple(components)


def _resolve_files_method(
components: tuple[Any, ...],
method_names: str | tuple[str, ...],
) -> Any | None:
for component in components:
method = _resolve_component_method(component, method_names)
if method is not None:
return method
return None


def _normalize_native_upload_result(raw: Any, file_name: str) -> dict[str, Any]:
payload = _maybe_model_dump(raw)
if not payload:
return {"success": True, "file_path": file_name}
if "file_path" not in payload and "path" not in payload:
payload["file_path"] = file_name
Comment on lines +265 to +266
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic fails to ensure the file_path key is present if the payload contains a path key instead. According to the ComputerBooter.upload_file contract, the returned dictionary must contain a file_path key. If path is present in the native response, it should be mapped to file_path to satisfy the interface requirements. Additionally, as this handles file attachments, ensure this logic is accompanied by unit tests.

Suggested change
if "file_path" not in payload and "path" not in payload:
payload["file_path"] = file_name
if "file_path" not in payload:
payload["file_path"] = payload.get("path") or file_name
References
  1. New functionality, such as handling attachments, should be accompanied by corresponding unit tests.

if "success" not in payload:
payload["success"] = not bool(payload.get("error") or payload.get("stderr"))
return payload


class CuaShellComponent(ShellComponent):
def __init__(self, sandbox: Any, os_type: str = "linux") -> None:
self._sandbox = sandbox
Expand Down Expand Up @@ -360,7 +397,7 @@ def __init__(
self, sandbox: Any, os_type: str = CUA_DEFAULT_CONFIG["os_type"]
) -> None:
self._shell = CuaShellComponent(sandbox, os_type=os_type)
self._fs = getattr(sandbox, "filesystem", None)
self._fs_components = _resolve_files_components(sandbox)
self._os_type = os_type.lower()
self._fallback = _PosixShellFileSystem(self._shell, self._os_type)

Expand All @@ -382,7 +419,9 @@ async def read_file(
offset: int | None = None,
limit: int | None = None,
) -> dict[str, Any]:
read_file = None if self._fs is None else getattr(self._fs, "read_file", None)
read_file = _resolve_files_method(
self._fs_components, ("read_file", "read_text")
)
if read_file is None:
return await self._fallback.read_file(path, encoding, offset, limit)
else:
Expand All @@ -405,19 +444,19 @@ async def write_file(
encoding: str = "utf-8",
) -> dict[str, Any]:
_ = mode
write_file = None if self._fs is None else getattr(self._fs, "write_file", None)
write_file = _resolve_files_method(
self._fs_components, ("write_file", "write_text")
)
if write_file is None:
return await self._fallback.write_file(path, content, mode, encoding)
else:
await _maybe_await(write_file(path, content))
return {"success": True, "path": path}

async def delete_file(self, path: str) -> dict[str, Any]:
delete = None
if self._fs is not None:
delete = getattr(self._fs, "delete", None) or getattr(
self._fs, "delete_file", None
)
delete = _resolve_files_method(
self._fs_components, ("delete", "delete_file", "remove")
)
if delete is None:
return await self._fallback.delete_file(path)
else:
Expand All @@ -429,7 +468,7 @@ async def list_dir(
path: str = ".",
show_hidden: bool = False,
) -> dict[str, Any]:
list_dir = None if self._fs is None else getattr(self._fs, "list_dir", None)
list_dir = _resolve_files_method(self._fs_components, ("list_dir", "list"))
if list_dir is not None:
entries = await _maybe_await(list_dir(path))
return {"success": True, "path": path, "entries": entries}
Expand Down Expand Up @@ -802,6 +841,15 @@ async def upload_file(self, path: str, file_name: str) -> dict:
return _maybe_model_dump(
await sandbox.upload_file(str(local_path), file_name)
)
Comment on lines 841 to 843
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency and to strictly adhere to the upload_file contract (which requires success and file_path keys), the legacy sandbox.upload_file result should also be processed by _normalize_native_upload_result. This refactors the logic into a shared helper function to avoid code duplication and ensures that new attachment handling functionality is properly normalized and tested.

            result = await sandbox.upload_file(str(local_path), file_name)
            return _normalize_native_upload_result(result, file_name)
References
  1. When implementing similar functionality for different cases (e.g., direct vs. quoted attachments), refactor the logic into a shared helper function to avoid code duplication.
  2. New functionality, such as handling attachments, should be accompanied by corresponding unit tests.

files_components = () if sandbox is None else _resolve_files_components(sandbox)
upload = _resolve_files_method(files_components, "upload")
if upload is not None:
result = await _maybe_await(upload(str(local_path), file_name))
return _normalize_native_upload_result(result, file_name)
write_bytes = _resolve_files_method(files_components, "write_bytes")
if write_bytes is not None:
result = await _maybe_await(write_bytes(file_name, local_path.read_bytes()))
return _normalize_native_upload_result(result, file_name)
if not _is_posix_os_type(self.os_type):
return _non_posix_filesystem_result(file_name, self.os_type)
result = await _write_base64_via_shell(
Expand Down
162 changes: 162 additions & 0 deletions tests/unit/test_cua_computer_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,27 @@ async def list_dir(self, path: str):
return [path]


class FakeFiles:
def __init__(self):
self.uploads = []
self.byte_writes = []
self.text_writes = []
self.text_reads = {}

async def upload(self, local_path: str, remote_path: str):
self.uploads.append((local_path, remote_path))

async def write_bytes(self, path: str, content: bytes):
self.byte_writes.append((path, content))

async def write_text(self, path: str, content: str):
self.text_writes.append((path, content))
self.text_reads[path] = content

async def read_text(self, path: str):
return self.text_reads[path]


class FakeMouse:
def __init__(self):
self.clicks = []
Expand Down Expand Up @@ -536,6 +557,41 @@ async def test_cua_shell_and_python_accept_sync_sdk_methods():
assert python_result["data"]["output"]["text"] == "sync"


@pytest.mark.asyncio
async def test_cua_filesystem_prefers_native_files_interface():
from astrbot.core.computer.booters.cua import CuaFileSystemComponent

sandbox = SandboxWithoutFilesystem()
sandbox.files = FakeFiles()

fs = CuaFileSystemComponent(sandbox)
await fs.write_file("hello.txt", "hello")
result = await fs.read_file("hello.txt")

assert sandbox.files.text_writes == [("hello.txt", "hello")]
assert result["success"] is True
assert result["content"] == "hello"
assert sandbox.shell.commands == []


@pytest.mark.asyncio
async def test_cua_filesystem_uses_legacy_filesystem_when_files_lacks_method():
from astrbot.core.computer.booters.cua import CuaFileSystemComponent

sandbox = SandboxWithoutFilesystem()
sandbox.files = type("UploadOnlyFiles", (), {"upload": FakeFiles().upload})()
sandbox.filesystem = FakeFilesystem()

fs = CuaFileSystemComponent(sandbox)
await fs.write_file("hello.txt", "hello")
result = await fs.read_file("hello.txt")

assert sandbox.filesystem.files == {"hello.txt": "hello"}
assert result["success"] is True
assert result["content"] == "hello"
assert sandbox.shell.commands == []


@pytest.mark.asyncio
async def test_cua_shell_normalizes_output_returncode_shape():
from astrbot.core.computer.booters.cua import CuaShellComponent
Expand Down Expand Up @@ -679,6 +735,112 @@ async def test_cua_upload_file_fallback_rejects_non_posix_os_type(tmp_path):
assert sandbox.shell.commands == []


@pytest.mark.asyncio
async def test_cua_upload_file_prefers_native_files_upload(tmp_path):
from astrbot.core.computer.booters.cua import (
CuaBooter,
CuaFileSystemComponent,
CuaGUIComponent,
CuaPythonComponent,
CuaShellComponent,
_CuaRuntime,
)

local_file = tmp_path / "upload.txt"
local_file.write_text("hello", encoding="utf-8")
sandbox = SandboxWithoutFilesystem()
sandbox.files = FakeFiles()
booter = CuaBooter()
booter._runtime = _CuaRuntime(
sandbox_cm=object(),
sandbox=sandbox,
shell=CuaShellComponent(sandbox),
python=CuaPythonComponent(sandbox),
fs=CuaFileSystemComponent(sandbox),
gui=CuaGUIComponent(sandbox),
)

result = await booter.upload_file(str(local_file), "remote.txt")

assert result["success"] is True
assert sandbox.files.uploads == [(str(local_file), "remote.txt")]
assert sandbox.shell.commands == []


@pytest.mark.asyncio
async def test_cua_upload_file_uses_native_write_bytes_when_upload_missing(tmp_path):
from astrbot.core.computer.booters.cua import (
CuaBooter,
CuaFileSystemComponent,
CuaGUIComponent,
CuaPythonComponent,
CuaShellComponent,
_CuaRuntime,
)

class FilesWithoutUpload:
def __init__(self):
self.byte_writes = []

async def write_bytes(self, path: str, content: bytes):
self.byte_writes.append((path, content))

local_file = tmp_path / "upload.txt"
local_file.write_bytes(b"hello-bytes")
sandbox = SandboxWithoutFilesystem()
sandbox.files = FilesWithoutUpload()
booter = CuaBooter()
booter._runtime = _CuaRuntime(
sandbox_cm=object(),
sandbox=sandbox,
shell=CuaShellComponent(sandbox),
python=CuaPythonComponent(sandbox),
fs=CuaFileSystemComponent(sandbox),
gui=CuaGUIComponent(sandbox),
)

result = await booter.upload_file(str(local_file), "remote.txt")

assert result["success"] is True
assert sandbox.files.byte_writes == [("remote.txt", b"hello-bytes")]
assert sandbox.shell.commands == []


@pytest.mark.asyncio
async def test_cua_upload_file_propagates_native_upload_failure_result(tmp_path):
from astrbot.core.computer.booters.cua import (
CuaBooter,
CuaFileSystemComponent,
CuaGUIComponent,
CuaPythonComponent,
CuaShellComponent,
_CuaRuntime,
)

class FailingFilesUpload:
async def upload(self, local_path: str, remote_path: str):
return {"success": False, "error": "disk full"}

local_file = tmp_path / "upload.txt"
local_file.write_text("hello", encoding="utf-8")
sandbox = SandboxWithoutFilesystem()
sandbox.files = FailingFilesUpload()
booter = CuaBooter()
booter._runtime = _CuaRuntime(
sandbox_cm=object(),
sandbox=sandbox,
shell=CuaShellComponent(sandbox),
python=CuaPythonComponent(sandbox),
fs=CuaFileSystemComponent(sandbox),
gui=CuaGUIComponent(sandbox),
)

result = await booter.upload_file(str(local_file), "remote.txt")

assert result["success"] is False
assert result["error"] == "disk full"


@pytest.mark.asyncio
async def test_cua_download_file_shell_quotes_remote_path(tmp_path):
from astrbot.core.computer.booters.cua import (
Expand Down
Loading