
Commit

Merge pull request #378 from allenai/shanea/storage-cleaner-cached-path
[Storage Cleaner] Migrate to cached_path
2015aroras committed Dec 7, 2023
2 parents 22cefa2 + a22cddc commit 1dbc346
Showing 1 changed file with 61 additions and 65 deletions.
126 changes: 61 additions & 65 deletions scripts/storage_cleaner.py
@@ -1,9 +1,7 @@
import argparse
import logging
import os
import re
import shutil
import tarfile
import tempfile
from abc import ABC, abstractmethod
from argparse import ArgumentParser, _SubParsersAction
@@ -13,10 +11,12 @@
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

import boto3.session
import botocore.exceptions as boto_exceptions
import google.cloud.storage as gcs
from cached_path import add_scheme_client, cached_path, set_cache_dir
from cached_path.schemes import S3Client
from google.api_core.exceptions import NotFound
from rich.progress import Progress

from olmo import util
from olmo.aliases import PathOrStr
@@ -31,11 +31,11 @@ class CleaningOperations(Enum):
DELETE_BAD_RUNS = auto()


class StorageType(Enum):
LOCAL_FS = auto()
GCS = auto()
S3 = auto()
R2 = auto()
class StorageType(util.StrEnum):
LOCAL_FS = ""
GCS = "gs"
S3 = "s3"
R2 = "r2"


class StorageAdapter(ABC):
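
The StorageType change above replaces opaque auto() values with URL-scheme strings. A minimal sketch of the pattern, assuming olmo.util.StrEnum behaves like the standard str-plus-Enum mixin (what Python 3.11 ships as enum.StrEnum):

    from enum import Enum

    class StrEnum(str, Enum):  # stand-in for olmo.util.StrEnum (assumption)
        def __str__(self) -> str:
            return self.value

    class StorageType(StrEnum):
        LOCAL_FS = ""  # local paths urlparse to scheme "", matching this value
        GCS = "gs"
        S3 = "s3"
        R2 = "r2"

    assert str(StorageType.R2) == "r2"          # the value doubles as the URL scheme
    assert StorageType("s3") is StorageType.S3  # scheme string -> member lookup

This is what lets the rest of the diff pass str(storage_type) around as a scheme name.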
@@ -80,15 +80,8 @@ def create_storage_adapter(cls, storage_type: StorageType):
return LocalFileSystemAdapter()
if storage_type == StorageType.GCS:
return GoogleCloudStorageAdapter()
if storage_type == StorageType.S3:
if storage_type in (StorageType.S3, StorageType.R2):
return S3StorageAdapter(storage_type)
if storage_type == StorageType.R2:
r2_account_id = os.environ.get("R2_ACCOUNT_ID")
if r2_account_id is None:
raise ValueError(
"R2_ACCOUNT_ID environment variable not set with R2 account id, cannot connect to R2"
)
return S3StorageAdapter(storage_type, endpoint_url=f"https://{r2_account_id}.r2.cloudflarestorage.com")

raise NotImplementedError(f"No storage adapter implemented for storage type {storage_type}")
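
The removed branch above is where R2's endpoint handling used to live; it now sits behind a util helper keyed by scheme. A hypothetical sketch of that helper, reconstructed from the removed code (the real util._get_s3_endpoint_url may differ):

    import os
    from typing import Optional

    def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
        if scheme != "r2":
            return None  # plain S3 keeps boto3's default AWS endpoint
        r2_account_id = os.environ.get("R2_ACCOUNT_ID")
        if r2_account_id is None:
            raise ValueError("R2_ACCOUNT_ID environment variable not set, cannot connect to R2")
        # R2 exposes an S3-compatible API at an account-specific endpoint
        return f"https://{r2_account_id}.r2.cloudflarestorage.com"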

@@ -136,11 +129,17 @@ def has_supported_archive_extension(self, path: PathOrStr) -> bool:
def _list_entries(
self, path: PathOrStr, include_files: bool = True, max_file_size: Optional[int] = None
) -> List[str]:
path = Path(path)
if path.is_dir():
path_obj = Path(path)
if path_obj.is_file():
if not self.has_supported_archive_extension(path_obj):
raise ValueError(f"File does not have a supported archive extension: {path}")

path_obj = cached_path(path_obj, extract_archive=True)

if path_obj.is_dir():
return [
entry.name
for entry in path.iterdir()
for entry in path_obj.iterdir()
if (
(include_files or not entry.is_file())
and (
@@ -149,16 +148,6 @@
)
]

if self.has_supported_archive_extension(path):
if not include_files or max_file_size is not None:
raise NotImplementedError("Filtering out entries from a tar file is not yet supported")

with tarfile.open(path) as tar:
log.info("Listing entries from archive %s", path)
return [
Path(tar_subpath).name for tar_subpath in tar.getnames() if len(Path(tar_subpath).parts) == 2
]

raise ValueError(f"Path does not correspond to directory or supported archive file: {path}")

def list_entries(self, path: str, max_file_size: Optional[int] = None) -> List[str]:
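
Archive listing in the local adapter now defers to cached_path: with extract_archive=True, cached_path unpacks the archive into its cache directory and returns the extracted directory's path, which can then be iterated like any other directory. A short sketch (the archive path is illustrative):

    from pathlib import Path
    from cached_path import cached_path

    extracted_dir: Path = cached_path("run123/checkpoints.tar.gz", extract_archive=True)
    entry_names = [entry.name for entry in extracted_dir.iterdir()]

Repeat calls hit the cache instead of re-extracting, which the old tarfile-based listing could not do.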
@@ -252,17 +241,6 @@ def _get_size(self, bucket_name: str, key: str) -> int:

return self._get_blob_size(blob)

def _download_file(self, bucket_name: str, key: str) -> str:
extension = "".join(Path(key).suffixes)
temp_file = self.local_fs_adapter.create_temp_file(suffix=extension)

bucket = self.gcs_client.bucket(bucket_name)
blob = bucket.get_blob(key)
if blob is None:
raise ValueError(f"Downloading invalid object: {self._get_path(bucket_name, key)}")
blob.download_to_filename(temp_file)
return temp_file

def _get_directory_entries(
self,
bucket_name: str,
@@ -304,8 +282,7 @@ def _list_entries(
bucket_name, key = self._get_bucket_name_and_key(path)

if self.local_fs_adapter.has_supported_archive_extension(path):
log.info("Downloading archive %s", path)
file_path = self._download_file(bucket_name, key)
file_path = str(cached_path(path, extract_archive=True))

if not include_files:
return self.local_fs_adapter.list_dirs(file_path)
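
The GCS adapter's hand-rolled _download_file disappears for the same reason: cached_path ships with a gs:// scheme client, so a remote archive can be fetched and extracted in one call (bucket and key below are illustrative):

    from cached_path import cached_path

    local_dir = cached_path("gs://example-bucket/run123/archive.tar.gz", extract_archive=True)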
@@ -354,10 +331,10 @@ def is_dir(self, path: str) -> bool:


class S3StorageAdapter(StorageAdapter):
def __init__(self, storage_type: StorageType, endpoint_url: Optional[str] = None):
def __init__(self, storage_type: StorageType):
super().__init__()
self._storage_type = storage_type
self._s3_client = util._get_s3_client(endpoint_url=endpoint_url)
self._s3_client = util._get_s3_client(str(storage_type))

self._local_fs_adapter: Optional[LocalFileSystemAdapter] = None
self._temp_dirs: List[tempfile.TemporaryDirectory] = []
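
The constructor no longer threads an endpoint_url through; instead it asks util for a client keyed by the scheme string. A hypothetical sketch of such a factory, modeled on the boto3 session setup in S3SchemeClient further down this diff (the real util._get_s3_client may differ):

    import boto3.session

    def _get_s3_client(scheme: str):
        # _get_s3_profile_name / _get_s3_endpoint_url are the util helpers
        # used by S3SchemeClient below; both switch on "s3" vs "r2"
        session = boto3.session.Session(profile_name=_get_s3_profile_name(scheme))
        return session.client("s3", endpoint_url=_get_s3_endpoint_url(scheme))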
@@ -396,25 +373,6 @@ def _get_size(self, bucket_name: str, key: str) -> int:
raise RuntimeError(f"Failed to get size for file: {self._get_path(bucket_name, key)}")
return head_response["ContentLength"]

def _download_file(self, bucket_name: str, key: str) -> str:
extension = "".join(Path(key).suffixes)
temp_file = self.local_fs_adapter.create_temp_file(suffix=extension)

size_in_bytes = self._get_size(bucket_name, key)

with Progress(transient=True) as progress:
download_task = progress.add_task(f"Downloading {key}", total=size_in_bytes)

def progress_callback(bytes_downloaded: int):
progress.update(download_task, advance=bytes_downloaded)

self._s3_client.download_file(bucket_name, key, temp_file, Callback=progress_callback)

if not self.local_fs_adapter.is_file(temp_file):
raise RuntimeError(f"Failed to download file: {self._get_path(bucket_name, key)}")

return temp_file

def _get_directory_entries(
self,
bucket_name: str,
@@ -454,8 +412,7 @@ def _list_entries(
bucket_name, key = self._get_bucket_name_and_key(path)

if self.local_fs_adapter.has_supported_archive_extension(path):
log.info("Downloading archive %s", path)
file_path = self._download_file(bucket_name, key)
file_path = str(cached_path(path, extract_archive=True))

if not include_files:
return self.local_fs_adapter.list_dirs(file_path)
@@ -672,6 +629,40 @@ def perform_operation(args: argparse.Namespace):
raise NotImplementedError(args.op)


def _add_cached_path_s3_client():
class S3SchemeClient(S3Client):
"""
A class that the `cached_path` module can use to retrieve resources from
S3 (and R2, which is S3-based). Refer to
[cached_path docs](https://github.com/allenai/cached_path/blob/main/docs/source/overview.md#supported-url-schemes).
"""

# This is used by cached_path to get the schemes that are handled by this client
scheme = ("s3", "r2")

def __init__(self, resource: str) -> None:
super().__init__(resource)
parsed_path = urlparse(resource)
bucket_name = parsed_path.netloc
key = parsed_path.path.lstrip("/")

profile_name = util._get_s3_profile_name(parsed_path.scheme)
endpoint_url = util._get_s3_endpoint_url(parsed_path.scheme)

session = boto3.session.Session(profile_name=profile_name)
s3_resource = session.resource("s3", endpoint_url=endpoint_url)
self.s3_object = s3_resource.Object(bucket_name, key) # type: ignore

add_scheme_client(S3SchemeClient)


def _setup_cached_path(args: argparse.Namespace):
if args.temp_dir is not None:
set_cache_dir(args.temp_dir)

_add_cached_path_s3_client()
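
Registering S3SchemeClient overrides cached_path's default s3:// handler and adds r2://, routing both through the profile-aware boto3 session above, while set_cache_dir points the cache at --temp_dir when one is given. A hedged usage sketch (URL illustrative):

    from cached_path import cached_path

    _add_cached_path_s3_client()
    local_file = cached_path("r2://example-bucket/run123/config.yaml")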


def _add_delete_subparser(subparsers: _SubParsersAction):
delete_runs_parser: ArgumentParser = subparsers.add_parser(
"clean", help="Delete bad runs (e.g. runs with no non-trivial checkpoints)"
@@ -712,6 +703,10 @@ def get_parser() -> ArgumentParser:
action="store_true",
help="If set, indicate actions but do not do them",
)
parser.add_argument(
"--temp_dir",
help="Directory where artifacts (e.g. unarchived directories) can be stored temporarily",
)

subparsers = parser.add_subparsers(dest="command", help="Cleaning commands", required=True)
_add_delete_subparser(subparsers)
@@ -723,6 +718,7 @@ def main():
args = get_parser().parse_args()

util.prepare_cli_environment()
_setup_cached_path(args)
perform_operation(args)


