Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions alphatrion/artifact/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,21 @@ def list_versions(self, repo_name: str) -> list[str]:
return self._backend.list_versions(repo_name)

def pull(
self, repo_name: str, version: str, output_dir: str | None = None
self, repo_name: str, version_or_filename: str, output_dir: str | None = None
) -> list[str]:
"""
Pull artifacts from the storage.

:param repo_name: the name of the repository to pull from
:param version: the version (tag) to pull
:param version_or_filename: the version (tag) or filename to pull.
For OCI backend, this is always the tag.
For S3 backend, if this matches a version folder (e.g., "v1") that contains
files, all files under that folder will be pulled. Otherwise, it is treated
as a file name under the repo and that single file will be pulled.
:param output_dir: optional directory to save files to
:return: list of absolute file paths that were downloaded
"""
return self._backend.pull(repo_name, version, output_dir)
return self._backend.pull(repo_name, version_or_filename, output_dir)

def delete(self, repo_name: str, versions: str | list[str]):
"""Delete specific versions from a repository."""
Expand Down
2 changes: 1 addition & 1 deletion alphatrion/artifact/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def pull(
"""Pull artifacts from the storage.

:param repo_name: the name of the repository to pull from
:param version: the version (tag) to pull
:param version: the version (tag) or filename to pull
:param output_dir: optional directory to save files to
:return: list of absolute file paths that were downloaded
"""
Expand Down
55 changes: 29 additions & 26 deletions alphatrion/artifact/s3_backend.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
"""S3-compatible artifact storage backend (push-only).

This backend supports pushing artifacts to S3 for archival/backup purposes.
list_versions() and pull() are not implemented - use AWS S3 console, CLI,
or SDK directly to retrieve artifacts if needed.
"""

import os

from alphatrion import envs
Expand Down Expand Up @@ -166,12 +159,14 @@ def list_versions(self, repo_name: str) -> list[str]:
raise RuntimeError(f"Failed to list versions: {e}") from e

def pull(
self, repo_name: str, version: str, output_dir: str | None = None
self, repo_name: str, version_or_filename: str, output_dir: str | None = None
) -> list[str]:
"""Pull (download) files from S3.

:param repo_name: Repository path (e.g., "org_id/team_id/exp_id/ckpt")
:param version: The filename to download (for flat structure) or folder name (for versioned structure)
:param version_or_filename: For S3 backend, if this matches a version folder (e.g., "v1")
that contains files, all files under that folder will be pulled. Otherwise, it is treated
as a file name under the repo and that single file will be pulled.
:param output_dir: Optional directory to save files. If None, downloads to current directory.
:return: List of absolute paths to downloaded files
"""
Expand All @@ -182,25 +177,15 @@ def pull(
download_dir = os.getcwd()

try:
# Check if version looks like a filename (has extension) or version folder
if "." in version:
# Single file: repo_name/version (e.g., "ckpt/checkpoint_123.pt")
s3_key = f"{repo_name}/{version}"
local_path = os.path.join(download_dir, version)

self._s3.download_file(self._bucket, s3_key, local_path)
return [local_path]
else:
# Version folder: repo_name/version/* (e.g., "ckpt/v1/*")
prefix = f"{repo_name}/{version}/"
# First, try as a version folder (e.g., "v1" or "v1.0")
prefix = f"{repo_name}/{version_or_filename}/"

response = self._s3.list_objects_v2(
Bucket=self._bucket, Prefix=prefix, Delimiter="/"
)

if "Contents" not in response:
return []
response = self._s3.list_objects_v2(
Bucket=self._bucket, Prefix=prefix, Delimiter="/"
)

if "Contents" in response and len(response["Contents"]) > 0:
# It's a version folder with files
downloaded_files = []
for obj in response["Contents"]:
s3_key = obj["Key"]
Expand All @@ -211,6 +196,24 @@ def pull(
downloaded_files.append(local_path)

return downloaded_files

# No objects were found under the folder prefix, so fall back to treating
# the name as a single file key. A missing object (404 / not found) yields
# an empty list; any other download error is re-raised.
s3_key = f"{repo_name}/{version_or_filename}"
local_path = os.path.join(download_dir, version_or_filename)

try:
self._s3.download_file(self._bucket, s3_key, local_path)
return [local_path]
except Exception as download_error:
error_msg = str(download_error).lower()
if "404" in error_msg or "not found" in error_msg:
# Neither folder nor file exists - return empty list
return []
raise
except RuntimeError:
raise
except Exception as e:
raise RuntimeError(f"Failed to pull artifacts from S3: {e}") from e

Expand Down
14 changes: 0 additions & 14 deletions alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import asyncio
import contextlib
import enum
import os
import shutil
import signal
import uuid
from abc import ABC, abstractmethod
Expand All @@ -11,11 +9,9 @@

from pydantic import BaseModel, Field, model_validator

from alphatrion import envs
from alphatrion.run.run import Run
from alphatrion.runtime.contextvars import current_exp_id
from alphatrion.runtime.runtime import global_runtime
from alphatrion.snapshot.snapshot import team_path
from alphatrion.storage import runtime as storage_runtime
from alphatrion.storage.sql_models import FINISHED_STATUS, Status
from alphatrion.types import CallableEntry, PostRunHookFn
Expand Down Expand Up @@ -212,8 +208,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
current_exp_id.reset(self._token)
self._runtime.current_experiment = None

self._cleanup_files()

def _start(
self,
name: str,
Expand Down Expand Up @@ -498,14 +492,6 @@ def start(
) -> "Experiment":
raise NotImplementedError

def _cleanup_files(self):
# remove the whole folder once the experiment is done.
if (
os.path.exists(team_path())
and os.getenv(envs.AUTO_CLEANUP, "true").lower() == "true"
):
shutil.rmtree(team_path(), ignore_errors=True)

def _start_signal_handlers(self):
loop = asyncio.get_running_loop()

Expand Down
8 changes: 4 additions & 4 deletions alphatrion/log/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ async def load_checkpoint(
output_dir: str | None = None,
) -> list[str]:
"""
Load checkpoint from artifact registry.
Load checkpoint from artifact registry. The path is expected to be in the format
"org_id/team_id/exp_id/ckpt/".

:param id: the id of the experiment.
:param version: the version of the checkpoint to load, default is "latest".
For oci backend, version is the tag of the artifact.
For s3 backend, version is the name of the file to load.
If version is "latest", the most recently modified file will be loaded.
If version is "latest", it will load the latest version (for oci backend) or
the file with the latest timestamp (for s3 backend).
:param type: the type of the checkpoint, can be "experiment" or "agent", default is "experiment".
:param output_dir: the directory to which the checkpoint will be loaded.
"""
Expand Down
2 changes: 1 addition & 1 deletion alphatrion/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def log_artifact(
Takes no arguments and returns nothing. This allows side effects after the artifact is saved,
such as logging or cleanup.

:return: the path of the logged artifact.
:return: the path of the logged artifact or None if no artifact was logged (e.g., if paths is empty).
OCI format: {org_id}/{team_id}/{exp_id}/{repo_name}:{version}
S3 format: {org_id}/{team_id}/{exp_id}/{repo_name}/{version}
"""
Expand Down
4 changes: 2 additions & 2 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ async def list_artifact_files(
arf = runtime.storage_runtime().artifact
org_id = info.context.org_id
file_paths = arf.pull(
repo_name=f"{org_id}/{team_id}/{repo_name}", version=tag
repo_name=f"{org_id}/{team_id}/{repo_name}", version_or_filename=tag
)

if not file_paths:
Expand Down Expand Up @@ -1010,7 +1010,7 @@ async def get_artifact_content(
# Pull the artifact - ORAS will manage temp directory
# Returns absolute paths to files in ORAS temp directory
file_paths = arf.pull(
repo_name=f"{org_id}/{team_id}/{repo_name}", version=tag
repo_name=f"{org_id}/{team_id}/{repo_name}", version_or_filename=tag
)

if not file_paths:
Expand Down
Empty file removed alphatrion/snapshot/__init__.py
Empty file.
61 changes: 0 additions & 61 deletions alphatrion/snapshot/snapshot.py

This file was deleted.

6 changes: 3 additions & 3 deletions tests/integration/test_oci_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def test_oci_backend_pull_single_file(artifact, unique_repo):
# Pull the file
output_dir = os.path.join(tmpdir, "download")
result = artifact.pull(
repo_name=unique_repo, version="v1", output_dir=output_dir
repo_name=unique_repo, version_or_filename="v1", output_dir=output_dir
)

# Verify file was downloaded
Expand Down Expand Up @@ -205,7 +205,7 @@ def test_oci_backend_pull_multiple_files(artifact, unique_repo):
# Pull the files
output_dir = os.path.join(tmpdir, "download")
result = artifact.pull(
repo_name=unique_repo, version="v1", output_dir=output_dir
repo_name=unique_repo, version_or_filename="v1", output_dir=output_dir
)

# Verify all files were downloaded
Expand Down Expand Up @@ -239,7 +239,7 @@ def test_oci_backend_pull_to_current_dir(artifact, unique_repo):
artifact.push(repo_name=unique_repo, paths=test_file, version="v1")

# Pull without output_dir
result = artifact.pull(repo_name=unique_repo, version="v1")
result = artifact.pull(repo_name=unique_repo, version_or_filename="v1")

# Should download to current directory
assert len(result) == 1
Expand Down
22 changes: 11 additions & 11 deletions tests/unit/artifact/test_s3_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def test_s3_backend_pull_single_file(s3_client):
output_dir = os.path.join(tmpdir, "download")
result = artifact.pull(
repo_name="org123/team456/exp1/ckpt",
version="checkpoint.pt",
version_or_filename="checkpoint.pt",
output_dir=output_dir,
)

Expand Down Expand Up @@ -351,7 +351,7 @@ def test_s3_backend_pull_version_folder(s3_client):
# Pull the version folder
output_dir = os.path.join(tmpdir, "download")
result = artifact.pull(
repo_name="org123/team456/exp1/ckpt", version="v1", output_dir=output_dir
repo_name="org123/team456/exp1/ckpt", version_or_filename="v1", output_dir=output_dir
)

# Verify all files were downloaded
Expand Down Expand Up @@ -385,7 +385,7 @@ def test_s3_backend_pull_to_current_dir(s3_client):

# Pull without output_dir
result = artifact.pull(
repo_name="org123/team456/test-repo", version="test.txt"
repo_name="org123/team456/test-repo", version_or_filename="test.txt"
)

# Should download to current directory
Expand All @@ -400,18 +400,18 @@ def test_s3_backend_pull_to_current_dir(s3_client):


def test_s3_backend_pull_nonexistent_file(s3_client):
"""Test pull with non-existent file raises error."""
"""Test pull with non-existent file returns empty list."""
from alphatrion.artifact.artifact import Artifact

artifact = Artifact()

with tempfile.TemporaryDirectory() as tmpdir:
with pytest.raises(RuntimeError, match="Failed to pull artifacts"):
artifact.pull(
repo_name="org123/team456/nonexistent",
version="missing.txt",
output_dir=tmpdir,
)
result = artifact.pull(
repo_name="org123/team456/nonexistent",
version_or_filename="missing.txt",
output_dir=tmpdir,
)
assert result == []


def test_s3_backend_pull_empty_version_folder(s3_client):
Expand All @@ -423,7 +423,7 @@ def test_s3_backend_pull_empty_version_folder(s3_client):
with tempfile.TemporaryDirectory() as tmpdir:
# Pull non-existent version folder
result = artifact.pull(
repo_name="org123/team456/exp1/ckpt", version="v999", output_dir=tmpdir
repo_name="org123/team456/exp1/ckpt", version_or_filename="v999", output_dir=tmpdir
)

# Should return empty list
Expand Down
Loading
Loading