Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix occasional KeyError in S3 logic #301

Merged
merged 9 commits into from
Sep 30, 2023
8 changes: 7 additions & 1 deletion olmo/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["OlmoError", "OlmoConfigurationError", "OlmoCliError"]
__all__ = ["OlmoError", "OlmoConfigurationError", "OlmoCliError", "OlmoNetworkError"]


class OlmoError(Exception):
Expand All @@ -17,3 +17,9 @@ class OlmoCliError(OlmoError):
"""
An error from incorrect CLI usage.
"""


class OlmoNetworkError(OlmoError):
    """
    An error with a network request, raised after retries are exhausted
    (e.g. by the S3 helpers in ``olmo.util``).
    """
112 changes: 80 additions & 32 deletions olmo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,23 @@
from typing import Any, Callable, Dict, Optional, TypeVar, Union

import boto3
import botocore.exceptions as boto_exceptions
import rich
import torch
import torch.distributed as dist
import torch.nn as nn
from botocore.config import Config
from rich.console import Console, ConsoleRenderable
from rich.highlighter import NullHighlighter
from rich.text import Text
from rich.traceback import Traceback

from .aliases import PathOrStr
from .config import LogFilterType
from .exceptions import OlmoCliError, OlmoError
from .exceptions import OlmoCliError, OlmoError, OlmoNetworkError

_log_extra_fields: Dict[str, Any] = {}
log = logging.getLogger(__name__)


def log_extra_field(field_name: str, field_value: Any) -> None:
Expand Down Expand Up @@ -127,9 +130,7 @@ def excepthook(exctype, value, traceback):
elif issubclass(exctype, OlmoError):
rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False)
else:
logging.getLogger().critical(
"Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback)
)
log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback))


def install_excepthook():
Expand Down Expand Up @@ -468,44 +469,91 @@ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool =
blob.upload_from_filename(source)


s3_client = boto3.client("s3")
s3_client = boto3.client("s3", config=Config(retries={"max_attempts": 10, "mode": "standard"}))


def _wait_before_retry(attempt: int):
time.sleep(min(0.5 * 2**attempt, 3.0))

def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
from botocore.exceptions import ClientError

def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3):
    """
    Upload ``source`` to ``s3://bucket_name/key``.

    Unless ``save_overwrite`` is set, the object's existence is first checked with
    ``head_object``, retried up to ``max_attempts`` times on retriable errors.

    :raises FileExistsError: if the object already exists and ``save_overwrite`` is False.
    :raises OlmoNetworkError: if the existence check or the upload itself keeps failing.
    """
    err: Optional[Exception] = None
    if not save_overwrite:
        for attempt in range(1, max_attempts + 1):
            try:
                s3_client.head_object(Bucket=bucket_name, Key=key)
                raise FileExistsError(
                    f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
                )
            except boto_exceptions.ClientError as e:
                # Use .get() instead of direct indexing so a malformed error response
                # can't raise KeyError, and compare strings instead of int(...) so a
                # non-numeric error code (e.g. "SlowDown") can't raise ValueError —
                # either would escape the retry loop entirely.
                if e.response.get("Error", {}).get("Code") in ("404", "NotFound", "NoSuchKey"):
                    err = None
                    break
                err = e

            if attempt < max_attempts:
                log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
                _wait_before_retry(attempt)

        if err is not None:
            raise OlmoNetworkError("Failed to check object existence during s3 upload") from err

    try:
        s3_client.upload_file(source, bucket_name, key)
    except boto_exceptions.ClientError as e:
        raise OlmoNetworkError("Failed to upload to s3") from e

def _s3_file_size(bucket_name: str, key: str) -> int:
from botocore.exceptions import ClientError

try:
return s3_client.head_object(Bucket=bucket_name, Key=key)["ContentLength"]
except ClientError as e:
if int(e.response["Error"]["Code"]) != 404:
raise
raise FileNotFoundError(f"s3://{bucket_name}/{key}")
def _s3_file_size(bucket_name: str, key: str, max_attempts: int = 3) -> int:
    """
    Return the size in bytes (``ContentLength``) of ``s3://bucket_name/key``.

    ``head_object`` is retried up to ``max_attempts`` times on retriable errors.

    :raises FileNotFoundError: if the object does not exist.
    :raises OlmoNetworkError: if the request keeps failing after all attempts.
    """
    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return s3_client.head_object(Bucket=bucket_name, Key=key)["ContentLength"]
        except boto_exceptions.ClientError as e:
            # Use .get() instead of direct indexing so a malformed error response
            # can't raise KeyError, and compare strings instead of int(...) so a
            # non-numeric error code (e.g. "SlowDown") can't raise ValueError —
            # either would escape the retry loop entirely.
            if e.response.get("Error", {}).get("Code") in ("404", "NotFound", "NoSuchKey"):
                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
            err = e

        if attempt < max_attempts:
            log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
            _wait_before_retry(attempt)

    raise OlmoNetworkError("Failed to get s3 file size") from err

try:
return s3_client.get_object(
Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
)["Body"].read()
except ClientError as e:
if int(e.response["Error"]["Code"]) != 404:
raise
raise FileNotFoundError(f"s3://{bucket_name}/{key}")

def _s3_get_bytes_range(
    bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
) -> bytes:
    """
    Fetch ``num_bytes`` bytes starting at offset ``bytes_start`` from
    ``s3://bucket_name/key``, retrying up to ``max_attempts`` times on retriable errors.

    :raises FileNotFoundError: if the object does not exist.
    :raises OlmoNetworkError: if the request keeps failing after all attempts.
    """
    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return s3_client.get_object(
                Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
            )["Body"].read()
        except boto_exceptions.ClientError as e:
            # GetObject reports a missing key as "NoSuchKey" rather than "404", so the
            # old int(e.response["Error"]["Code"]) == 404 check could never match and
            # int(...) itself raised ValueError on the non-numeric code. Use .get() to
            # avoid KeyError on malformed responses and compare against the known
            # not-found codes as strings.
            if e.response.get("Error", {}).get("Code") in ("404", "NotFound", "NoSuchKey"):
                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
            err = e
        except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e:
            # ResponseStreamingError (subclass of HTTPClientError) can happen as
            # a result of a failed read from the stream (http.client.IncompleteRead).
            # Retrying can help in this case.
            err = e

        if attempt < max_attempts:
            log.warning(
                "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
            )
            _wait_before_retry(attempt)

    # When torch's DataLoader intercepts exceptions, it may try to re-raise them
    # by recalling their constructor with a single message arg. Torch has some
    # logic to deal with the absence of a single-parameter constructor, but it
    # doesn't gracefully handle other possible failures in calling such a constructor
    # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
    # in us losing the true exception info. To avoid this, we change the exception
    # to a type that has a single-parameter constructor.
    raise OlmoNetworkError("Failed to get bytes range from s3") from err


def is_weight_decay_module(module: nn.Module) -> bool:
Expand Down
Loading