Skip to content

Commit

Permalink
Merge pull request #389 from allenai/shanea/add-r2-scheme
Browse files Browse the repository at this point in the history
Add support for R2 URLs
  • Loading branch information
2015aroras authored Dec 5, 2023
2 parents ac01778 + 1194989 commit 22cefa2
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 23 deletions.
10 changes: 9 additions & 1 deletion olmo/data/memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import torch
from torch.utils.data import Dataset

from olmo.exceptions import OlmoEnvironmentError

from ..aliases import PathOrStr
from ..util import _get_s3_client, file_size, get_bytes_range

Expand Down Expand Up @@ -71,7 +73,13 @@ def max_seq_len(self) -> int:
@property
def offsets(self) -> List[Tuple[int, int]]:
# Create the global S3 client up front to work around a threading issue in boto.
_get_s3_client()
_get_s3_client("s3")
try:
_get_s3_client("r2")
except OlmoEnvironmentError:
# R2 might not be needed, so ignore this error. We will get an error
# later if R2 is needed.
pass

if self._mmap_offsets is None:
import concurrent.futures
Expand Down
8 changes: 7 additions & 1 deletion olmo/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["OlmoError", "OlmoConfigurationError", "OlmoCliError", "OlmoNetworkError"]
__all__ = ["OlmoError", "OlmoConfigurationError", "OlmoCliError", "OlmoEnvironmentError", "OlmoNetworkError"]


class OlmoError(Exception):
Expand All @@ -19,6 +19,12 @@ class OlmoCliError(OlmoError):
"""


class OlmoEnvironmentError(OlmoError):
    """
    An error from incorrect environment variables.

    Raised, for example, when an ``r2://`` URL is used but the ``R2_PROFILE``
    or ``R2_ENDPOINT_URL`` environment variable is not set.
    """


class OlmoNetworkError(OlmoError):
"""
An error with a network request.
Expand Down
84 changes: 63 additions & 21 deletions olmo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
from rich.traceback import Traceback

from .aliases import PathOrStr
from .exceptions import OlmoCliError, OlmoError, OlmoNetworkError, OlmoThreadError
from .exceptions import (
OlmoCliError,
OlmoEnvironmentError,
OlmoError,
OlmoNetworkError,
OlmoThreadError,
)
from .torch_util import get_global_rank, get_local_rank, get_node_rank, is_distributed


Expand Down Expand Up @@ -325,8 +331,8 @@ def file_size(path: PathOrStr) -> int:
parsed = urlparse(str(path))
if parsed.scheme == "gs":
return _gcs_file_size(parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme == "s3":
return _s3_file_size(parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme in ("s3", "r2"):
return _s3_file_size(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme == "file":
return file_size(str(path).replace("file://", "", 1))
else:
def upload(source: PathOrStr, target: str, save_overwrite: bool = False):
    """
    Upload a local file to a remote ``target`` URL.

    Supports ``gs://``, ``s3://`` and ``r2://`` target schemes; R2 targets
    are handled by the S3 helpers with an R2-specific client.

    :param source: Local path of the file to upload.
    :param target: Destination URL (e.g. ``s3://bucket/key``).
    :param save_overwrite: If True, overwrite an existing remote object.

    :raises NotImplementedError: If the target URL scheme is unsupported.
    """
    parsed = urlparse(target)
    if parsed.scheme == "gs":
        _gcs_upload(source, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
    elif parsed.scheme in ("s3", "r2"):
        # The scheme is forwarded so the helper can build the right client.
        _s3_upload(source, parsed.scheme, parsed.netloc, parsed.path.strip("/"), save_overwrite=save_overwrite)
    else:
        raise NotImplementedError(f"Upload not implemented for '{parsed.scheme}' scheme")

Expand All @@ -357,8 +363,10 @@ def get_bytes_range(source: PathOrStr, bytes_start: int, num_bytes: int) -> byte
parsed = urlparse(str(source))
if parsed.scheme == "gs":
return _gcs_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
elif parsed.scheme == "s3":
return _s3_get_bytes_range(parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes)
elif parsed.scheme in ("s3", "r2"):
return _s3_get_bytes_range(
parsed.scheme, parsed.netloc, parsed.path.strip("/"), bytes_start, num_bytes
)
elif parsed.scheme == "file":
return get_bytes_range(str(source).replace("file://", "", 1), bytes_start, num_bytes)
else:
Expand All @@ -376,8 +384,8 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
parsed = urlparse(str(dir))
if parsed.scheme == "gs":
raise NotImplementedError
elif parsed.scheme == "s3":
return _s3_find_latest_checkpoint(parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme in ("s3", "r2"):
return _s3_find_latest_checkpoint(parsed.scheme, parsed.netloc, parsed.path.strip("/"))
elif parsed.scheme == "file":
return find_latest_checkpoint(str(dir).replace("file://", "", 1))
else:
Expand Down Expand Up @@ -438,11 +446,43 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes
return blob.download_as_bytes(start=bytes_start, end=bytes_start + num_bytes - 1)


def _get_s3_profile_name(scheme: str) -> Optional[str]:
if scheme == "s3":
# For backwards compatibility, we assume S3 uses the default profile if S3_PROFILE is not set.
return os.environ.get("S3_PROFILE")
if scheme == "r2":
profile_name = os.environ.get("R2_PROFILE")
if profile_name is None:
raise OlmoEnvironmentError(
"R2 profile name is not set. Did you forget to set the 'R2_PROFILE' env var?"
)

return profile_name

raise NotImplementedError(f"Cannot get profile name for scheme {scheme}")


def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
if scheme == "s3":
return None
if scheme == "r2":
r2_endpoint_url = os.environ.get("R2_ENDPOINT_URL")
if r2_endpoint_url is None:
raise OlmoEnvironmentError(
"R2 endpoint url is not set. Did you forget to set the 'R2_ENDPOINT_URL' env var?"
)

return r2_endpoint_url

raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")


@cache
def _get_s3_client(scheme: str):
    """
    Build a boto3 S3-compatible client for the given URL scheme.

    :param scheme: Either ``"s3"`` or ``"r2"``; the profile name and endpoint
        URL are resolved per scheme by the corresponding helpers.
    :returns: A boto3 S3 client. Cached via ``functools.cache`` so repeated
        calls for the same scheme share one client.

    :raises OlmoEnvironmentError: Propagated from the helpers when R2 is
        requested but its env vars are not set.
    """
    session = boto3.Session(profile_name=_get_s3_profile_name(scheme))
    return session.client(
        "s3",
        endpoint_url=_get_s3_endpoint_url(scheme),
        config=Config(retries={"max_attempts": 10, "mode": "standard"}),
        # OLMO_NO_SSL=1 disables SSL, e.g. for local or test endpoints.
        use_ssl=not int(os.environ.get("OLMO_NO_SSL", "0")),
    )
Expand All @@ -452,12 +492,14 @@ def _wait_before_retry(attempt: int):
time.sleep(min(0.5 * 2**attempt, 3.0))


def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3):
def _s3_upload(
source: Path, scheme: str, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3
):
err: Optional[Exception] = None
if not save_overwrite:
for attempt in range(1, max_attempts + 1):
try:
_get_s3_client().head_object(Bucket=bucket_name, Key=key)
_get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
raise FileExistsError(
f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
)
Expand All @@ -475,16 +517,16 @@ def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool =
raise OlmoNetworkError("Failed to check object existence during s3 upload") from err

try:
_get_s3_client().upload_file(source, bucket_name, key)
_get_s3_client(scheme).upload_file(source, bucket_name, key)
except boto_exceptions.ClientError as e:
raise OlmoNetworkError("Failed to upload to s3") from e


def _s3_file_size(bucket_name: str, key: str, max_attempts: int = 3) -> int:
def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
err: Optional[Exception] = None
for attempt in range(1, max_attempts + 1):
try:
return _get_s3_client().head_object(Bucket=bucket_name, Key=key)["ContentLength"]
return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
except boto_exceptions.ClientError as e:
if int(e.response["Error"]["Code"]) == 404:
raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
Expand All @@ -498,13 +540,13 @@ def _s3_file_size(bucket_name: str, key: str, max_attempts: int = 3) -> int:


def _s3_get_bytes_range(
bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
) -> bytes:
err: Optional[Exception] = None
for attempt in range(1, max_attempts + 1):
try:
return (
_get_s3_client()
_get_s3_client(scheme)
.get_object(
Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
)["Body"]
Expand Down Expand Up @@ -536,10 +578,10 @@ def _s3_get_bytes_range(
raise OlmoNetworkError("Failed to get bytes range from s3") from err


def _s3_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
def _s3_find_latest_checkpoint(scheme: str, bucket_name: str, prefix: str) -> Optional[str]:
if not prefix.endswith("/"):
prefix = f"{prefix}/"
response = _get_s3_client().list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
assert not response["IsTruncated"] # need to handle this if it happens
latest_step = 0
latest_checkpoint: Optional[str] = None
Expand Down

0 comments on commit 22cefa2

Please sign in to comment.