diff --git a/alphatrion/artifact/artifact.py b/alphatrion/artifact/artifact.py index bdaa463..edd5db2 100644 --- a/alphatrion/artifact/artifact.py +++ b/alphatrion/artifact/artifact.py @@ -56,17 +56,21 @@ def list_versions(self, repo_name: str) -> list[str]: return self._backend.list_versions(repo_name) def pull( - self, repo_name: str, version: str, output_dir: str | None = None + self, repo_name: str, version_or_filename: str, output_dir: str | None = None ) -> list[str]: """ Pull artifacts from the storage. :param repo_name: the name of the repository to pull from - :param version: the version (tag) to pull + :param version_or_filename: the version (tag) or filename to pull. + For OCI backend, this is always the tag. + For S3 backend, a matching version folder (e.g., "v1") is tried + first and all files under it are pulled. Otherwise, the value is + treated as a file name under the repo and that file is pulled. :param output_dir: optional directory to save files to :return: list of absolute file paths that were downloaded """ - return self._backend.pull(repo_name, version, output_dir) + return self._backend.pull(repo_name, version_or_filename, output_dir) def delete(self, repo_name: str, versions: str | list[str]): """Delete specific versions from a repository.""" diff --git a/alphatrion/artifact/base.py b/alphatrion/artifact/base.py index 8c82e4a..b8e907f 100644 --- a/alphatrion/artifact/base.py +++ b/alphatrion/artifact/base.py @@ -38,7 +38,7 @@ def pull( """Pull artifacts from the storage.
:param repo_name: the name of the repository to pull from - :param version: the version (tag) to pull + :param version: the version (tag) or filename to pull :param output_dir: optional directory to save files to :return: list of absolute file paths that were downloaded """ diff --git a/alphatrion/artifact/s3_backend.py b/alphatrion/artifact/s3_backend.py index c450929..70e6466 100644 --- a/alphatrion/artifact/s3_backend.py +++ b/alphatrion/artifact/s3_backend.py @@ -1,10 +1,3 @@ -"""S3-compatible artifact storage backend (push-only). - -This backend supports pushing artifacts to S3 for archival/backup purposes. -list_versions() and pull() are not implemented - use AWS S3 console, CLI, -or SDK directly to retrieve artifacts if needed. -""" - import os from alphatrion import envs @@ -166,12 +159,14 @@ def list_versions(self, repo_name: str) -> list[str]: raise RuntimeError(f"Failed to list versions: {e}") from e def pull( - self, repo_name: str, version: str, output_dir: str | None = None + self, repo_name: str, version_or_filename: str, output_dir: str | None = None ) -> list[str]: """Pull (download) files from S3. :param repo_name: Repository path (e.g., "org_id/team_id/exp_id/ckpt") - :param version: The filename to download (for flat structure) or folder name (for versioned structure) + :param version_or_filename: A version folder (e.g., "v1") or a file name under the repo. + The version folder is tried first and all files under it are pulled. Otherwise, the value + is treated as a file name and that file is pulled; if neither exists, nothing is pulled. :param output_dir: Optional directory to save files. If None, downloads to current directory. :return: List of absolute paths to downloaded files """ @@ -182,25 +177,15 @@ def pull( download_dir = os.getcwd() try: - # Check if version looks like a filename (has extension) or version folder - if "."
in version: - # Single file: repo_name/version (e.g., "ckpt/checkpoint_123.pt") - s3_key = f"{repo_name}/{version}" - local_path = os.path.join(download_dir, version) - - self._s3.download_file(self._bucket, s3_key, local_path) - return [local_path] - else: - # Version folder: repo_name/version/* (e.g., "ckpt/v1/*") - prefix = f"{repo_name}/{version}/" + # First, try as a version folder (e.g., "v1" or "v1.0") + prefix = f"{repo_name}/{version_or_filename}/" - response = self._s3.list_objects_v2( - Bucket=self._bucket, Prefix=prefix, Delimiter="/" - ) - - if "Contents" not in response: - return [] + response = self._s3.list_objects_v2( + Bucket=self._bucket, Prefix=prefix, Delimiter="/" + ) + if "Contents" in response and len(response["Contents"]) > 0: + # It's a version folder with files downloaded_files = [] for obj in response["Contents"]: s3_key = obj["Key"] @@ -211,6 +196,24 @@ def pull( downloaded_files.append(local_path) return downloaded_files + + # No files were listed under the folder prefix, so fall back to + # treating the argument as a single file name under the repo. + # A missing file (404) is reported as an empty list, not an error. + s3_key = f"{repo_name}/{version_or_filename}" + local_path = os.path.join(download_dir, version_or_filename) + + try: + self._s3.download_file(self._bucket, s3_key, local_path) + return [local_path] + except Exception as download_error: + error_msg = str(download_error).lower() + if "404" in error_msg or "not found" in error_msg: + # Neither folder nor file exists - return empty list + return [] + raise + except RuntimeError: + raise except Exception as e: raise RuntimeError(f"Failed to pull artifacts from S3: {e}") from e diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index 6da6ca0..376bce9 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -1,8 +1,6 @@ import asyncio import contextlib import enum -import os -import shutil import signal import uuid from abc import ABC, abstractmethod @@
-11,11 +9,9 @@ from pydantic import BaseModel, Field, model_validator -from alphatrion import envs from alphatrion.run.run import Run from alphatrion.runtime.contextvars import current_exp_id from alphatrion.runtime.runtime import global_runtime -from alphatrion.snapshot.snapshot import team_path from alphatrion.storage import runtime as storage_runtime from alphatrion.storage.sql_models import FINISHED_STATUS, Status from alphatrion.types import CallableEntry, PostRunHookFn @@ -212,8 +208,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): current_exp_id.reset(self._token) self._runtime.current_experiment = None - self._cleanup_files() - def _start( self, name: str, @@ -498,14 +492,6 @@ def start( ) -> "Experiment": raise NotImplementedError - def _cleanup_files(self): - # remove the whole folder once the experiment is done. - if ( - os.path.exists(team_path()) - and os.getenv(envs.AUTO_CLEANUP, "true").lower() == "true" - ): - shutil.rmtree(team_path(), ignore_errors=True) - def _start_signal_handlers(self): loop = asyncio.get_running_loop() diff --git a/alphatrion/log/load.py b/alphatrion/log/load.py index 4b64eb6..5a1f270 100644 --- a/alphatrion/log/load.py +++ b/alphatrion/log/load.py @@ -35,13 +35,13 @@ async def load_checkpoint( output_dir: str | None = None, ) -> list[str]: """ - Load checkpoint from artifact registry. + Load checkpoint from artifact registry. The path is expected to be in the format + "org_id/team_id/exp_id/ckpt/". :param id: the id of the experiment. :param version: the version of the checkpoint to load, default is "latest". - For oci backend, version is the tag of the artifact. - For s3 backend, version is the name of the file to load. - If version is "latest", the most recently modified file will be loaded. + If version is "latest", it will load the latest version (for oci backend) or + the file with the latest timestamp (for s3 backend).
:param type: the type of the checkpoint, can be "experiment" or "agent", default is "experiment". :param output_dir: the directory to which the checkpoint will be loaded. """ diff --git a/alphatrion/log/log.py b/alphatrion/log/log.py index c56098e..7507782 100644 --- a/alphatrion/log/log.py +++ b/alphatrion/log/log.py @@ -37,7 +37,7 @@ async def log_artifact( Takes no arguments and returns nothing. This allows side effects after the artifact is saved, such as logging or cleanup. - :return: the path of the logged artifact. + :return: the path of the logged artifact or None if no artifact was logged (e.g., if paths is empty). OCI format: {org_id}/{team_id}/{exp_id}/{repo_name}:{version} S3 format: {org_id}/{team_id}/{exp_id}/{repo_name}/{version} """ diff --git a/alphatrion/server/graphql/resolvers.py b/alphatrion/server/graphql/resolvers.py index 4e1954c..2f8fe75 100644 --- a/alphatrion/server/graphql/resolvers.py +++ b/alphatrion/server/graphql/resolvers.py @@ -951,7 +951,7 @@ async def list_artifact_files( arf = runtime.storage_runtime().artifact org_id = info.context.org_id file_paths = arf.pull( - repo_name=f"{org_id}/{team_id}/{repo_name}", version=tag + repo_name=f"{org_id}/{team_id}/{repo_name}", version_or_filename=tag ) if not file_paths: @@ -1010,7 +1010,7 @@ async def get_artifact_content( # Pull the artifact - ORAS will manage temp directory # Returns absolute paths to files in ORAS temp directory file_paths = arf.pull( - repo_name=f"{org_id}/{team_id}/{repo_name}", version=tag + repo_name=f"{org_id}/{team_id}/{repo_name}", version_or_filename=tag ) if not file_paths: diff --git a/alphatrion/snapshot/__init__.py b/alphatrion/snapshot/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/alphatrion/snapshot/snapshot.py b/alphatrion/snapshot/snapshot.py deleted file mode 100644 index a16cd8d..0000000 --- a/alphatrion/snapshot/snapshot.py +++ /dev/null @@ -1,61 +0,0 @@ -from pathlib import Path - -from alphatrion.runtime.contextvars 
import current_exp_id, current_run_id -from alphatrion.runtime.runtime import global_runtime - -"""The snapshot is organized in a hierarchical directory structure as follows: -└── snapshots - ├── team_26c73273-dc1f-40f2-a4ed-7eee35ae05d4 - │   └── project_7481eedd-86be-4d39-9d0d-679ba83ab7b6 - │   └── user_95eb9982-56fd-4e1a-8498-28312f4d3ba5 - │   └── exp_e5b08511-eba4-4ee7-bd10-e632110996b5 - │   └── run_5b08673b-b2cb-41fb-90f0-c15b450f4433 - └── team_dcafdce3-dfde-47b4-994c-93880848ca91 - ├── project_449a560d-cc12-46eb-9058-351cfe56433b - │   ├── user_450704dd-37f4-4aa7-97f8-b34e42576b09 - │   │   ├── exp_c855725d-891f-4f61-8f7d-b6f40c94509f - │   │   └── exp_f751ec9c-4aa2-46c5-8cab-1f92af6f001d - │   │   ├── checkpoints - │   │   ├── run_94a82594-01a7-463f-b63b-ab896be9830e - │   │   │   └── result.json - │   │   └── run_c0e3c730-c213-4a8e-9e10-7af57fcf8bf9 - │   └── user_f303c129-c4b5-4f24-957c-d28dd78cce89 - │   └── exp_efeb5430-6593-4675-969c-325aa25af986 - │   ├── run_1c89b44d-15de-464a-9e1c-c6aab8a82a7d - │   └── run_7990bcf3-f864-4442-ae35-00dd8329f7c5 - └── project_5cb62b0b-83d3-49fa-956e-cd51df3e7891 -""" - - -def snapshot_path() -> str: - runtime = global_runtime() - return ( - Path(runtime.root_path) - / "snapshots" - / f"team_{runtime.team_id}" - / f"user_{runtime.user_id}" - / f"exp_{current_exp_id.get()}" - / f"run_{current_run_id.get()}" - ) - - -def checkpoint_path() -> str: - runtime = global_runtime() - return ( - Path(runtime.root_path) - / "snapshots" - / f"org_{runtime.org_id}" - / f"team_{runtime.team_id}" - / f"exp_{current_exp_id.get()}" - / "checkpoints" - ) - - -def team_path() -> str: - runtime = global_runtime() - return ( - Path(runtime.root_path) - / "snapshots" - / f"org_{runtime.org_id}" - / f"team_{runtime.team_id}" - ) diff --git a/tests/integration/test_oci_backend.py b/tests/integration/test_oci_backend.py index 9d215a6..bf76ff4 100644 --- a/tests/integration/test_oci_backend.py +++ b/tests/integration/test_oci_backend.py @@ 
-176,7 +176,7 @@ def test_oci_backend_pull_single_file(artifact, unique_repo): # Pull the file output_dir = os.path.join(tmpdir, "download") result = artifact.pull( - repo_name=unique_repo, version="v1", output_dir=output_dir + repo_name=unique_repo, version_or_filename="v1", output_dir=output_dir ) # Verify file was downloaded @@ -205,7 +205,7 @@ def test_oci_backend_pull_multiple_files(artifact, unique_repo): # Pull the files output_dir = os.path.join(tmpdir, "download") result = artifact.pull( - repo_name=unique_repo, version="v1", output_dir=output_dir + repo_name=unique_repo, version_or_filename="v1", output_dir=output_dir ) # Verify all files were downloaded @@ -239,7 +239,7 @@ def test_oci_backend_pull_to_current_dir(artifact, unique_repo): artifact.push(repo_name=unique_repo, paths=test_file, version="v1") # Pull without output_dir - result = artifact.pull(repo_name=unique_repo, version="v1") + result = artifact.pull(repo_name=unique_repo, version_or_filename="v1") # Should download to current directory assert len(result) == 1 diff --git a/tests/unit/artifact/test_s3_backend.py b/tests/unit/artifact/test_s3_backend.py index 218bb15..385f81e 100644 --- a/tests/unit/artifact/test_s3_backend.py +++ b/tests/unit/artifact/test_s3_backend.py @@ -317,7 +317,7 @@ def test_s3_backend_pull_single_file(s3_client): output_dir = os.path.join(tmpdir, "download") result = artifact.pull( repo_name="org123/team456/exp1/ckpt", - version="checkpoint.pt", + version_or_filename="checkpoint.pt", output_dir=output_dir, ) @@ -351,7 +351,7 @@ def test_s3_backend_pull_version_folder(s3_client): # Pull the version folder output_dir = os.path.join(tmpdir, "download") result = artifact.pull( - repo_name="org123/team456/exp1/ckpt", version="v1", output_dir=output_dir + repo_name="org123/team456/exp1/ckpt", version_or_filename="v1", output_dir=output_dir ) # Verify all files were downloaded @@ -385,7 +385,7 @@ def test_s3_backend_pull_to_current_dir(s3_client): # Pull without output_dir 
result = artifact.pull( - repo_name="org123/team456/test-repo", version="test.txt" + repo_name="org123/team456/test-repo", version_or_filename="test.txt" ) # Should download to current directory @@ -400,18 +400,18 @@ def test_s3_backend_pull_to_current_dir(s3_client): def test_s3_backend_pull_nonexistent_file(s3_client): - """Test pull with non-existent file raises error.""" + """Test pull with non-existent file returns empty list.""" from alphatrion.artifact.artifact import Artifact artifact = Artifact() with tempfile.TemporaryDirectory() as tmpdir: - with pytest.raises(RuntimeError, match="Failed to pull artifacts"): - artifact.pull( - repo_name="org123/team456/nonexistent", - version="missing.txt", - output_dir=tmpdir, - ) + result = artifact.pull( + repo_name="org123/team456/nonexistent", + version_or_filename="missing.txt", + output_dir=tmpdir, + ) + assert result == [] def test_s3_backend_pull_empty_version_folder(s3_client): @@ -423,7 +423,7 @@ def test_s3_backend_pull_empty_version_folder(s3_client): with tempfile.TemporaryDirectory() as tmpdir: # Pull non-existent version folder result = artifact.pull( - repo_name="org123/team456/exp1/ckpt", version="v999", output_dir=tmpdir + repo_name="org123/team456/exp1/ckpt", version_or_filename="v999", output_dir=tmpdir ) # Should return empty list diff --git a/tests/unit/experiment/test_experiment.py b/tests/unit/experiment/test_experiment.py index 72c8f54..f82720c 100644 --- a/tests/unit/experiment/test_experiment.py +++ b/tests/unit/experiment/test_experiment.py @@ -5,9 +5,7 @@ import uuid from datetime import datetime, timedelta from functools import partial -from pathlib import Path -import faker import pytest from alphatrion.experiment import base as experiment @@ -18,7 +16,6 @@ from alphatrion.experiment.craft_experiment import CraftExperiment from alphatrion.runtime.contextvars import current_exp_id from alphatrion.runtime.runtime import global_runtime, init -from alphatrion.snapshot.snapshot import 
checkpoint_path from alphatrion.storage.sql_models import Status @@ -119,25 +116,6 @@ def test_config(self): ), ) - -@pytest.mark.asyncio -async def test_snapshot_path(): - team_id = uuid.uuid4() - user_id = uuid.uuid4() - org_id = uuid.uuid4() - init(team_id=team_id, user_id=user_id, org_id=org_id) - - async with CraftExperiment.start(name=faker.Faker().word()) as exp: - assert checkpoint_path() == ( - Path(exp._runtime.root_path) - / "snapshots" - / f"org_{org_id}" - / f"team_{team_id}" - / f"exp_{exp.id}" - / "checkpoints" - ) - - @pytest.mark.asyncio async def test_experiment_with_done(): init(