Skip to content

Commit

Permalink
Merge pull request #301 from allenai/shanea/fix-s3-keyerror-failures
Browse files Browse the repository at this point in the history
Fix occasional KeyError in S3 logic
  • Loading branch information
2015aroras committed Sep 30, 2023
2 parents e7b92a6 + 019ecf3 commit 0a1455b
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 33 deletions.
8 changes: 7 additions & 1 deletion olmo/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Public exception names exported by this module. OlmoNetworkError is included
# so callers can catch retriable-network failures raised by the S3 helpers.
__all__ = ["OlmoError", "OlmoConfigurationError", "OlmoCliError", "OlmoNetworkError"]


class OlmoError(Exception):
Expand All @@ -17,3 +17,9 @@ class OlmoCliError(OlmoError):
"""
An error from incorrect CLI usage.
"""


class OlmoNetworkError(OlmoError):
    """
    Raised when a network request fails, e.g. an S3 call that has exhausted its retries.
    """
112 changes: 80 additions & 32 deletions olmo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,23 @@
from typing import Any, Callable, Dict, Optional, TypeVar, Union

import boto3
import botocore.exceptions as boto_exceptions
import rich
import torch
import torch.distributed as dist
import torch.nn as nn
from botocore.config import Config
from rich.console import Console, ConsoleRenderable
from rich.highlighter import NullHighlighter
from rich.text import Text
from rich.traceback import Traceback

from .aliases import PathOrStr
from .config import LogFilterType
from .exceptions import OlmoCliError, OlmoError
from .exceptions import OlmoCliError, OlmoError, OlmoNetworkError

_log_extra_fields: Dict[str, Any] = {}
log = logging.getLogger(__name__)


def log_extra_field(field_name: str, field_value: Any) -> None:
Expand Down Expand Up @@ -127,9 +130,7 @@ def excepthook(exctype, value, traceback):
elif issubclass(exctype, OlmoError):
rich.get_console().print(Text(f"{exctype.__name__}:", style="red"), value, highlight=False)
else:
logging.getLogger().critical(
"Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback)
)
log.critical("Uncaught %s: %s", exctype.__name__, value, exc_info=(exctype, value, traceback))


def install_excepthook():
Expand Down Expand Up @@ -468,44 +469,91 @@ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool =
blob.upload_from_filename(source)


# Module-level S3 client shared by the helpers below. "standard" retry mode with
# up to 10 attempts lets botocore transparently retry throttling and other
# transient errors before our own application-level retry loops even see them.
s3_client = boto3.client("s3", config=Config(retries={"max_attempts": 10, "mode": "standard"}))


def _wait_before_retry(attempt: int):
time.sleep(min(0.5 * 2**attempt, 3.0))

def _s3_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False, max_attempts: int = 3):
    """Upload the local file ``source`` to ``s3://bucket_name/key``.

    When ``save_overwrite`` is False, first checks (with retries) that the key
    does not already exist.

    :param source: local path of the file to upload.
    :param bucket_name: destination S3 bucket.
    :param key: destination object key.
    :param save_overwrite: if True, skip the existence check and overwrite any existing object.
    :param max_attempts: number of attempts for the existence check before giving up.
    :raises FileExistsError: the object already exists and ``save_overwrite`` is False.
    :raises OlmoNetworkError: the existence check kept failing after ``max_attempts``
        tries, or the upload itself failed.
    """
    err: Optional[Exception] = None
    if not save_overwrite:
        for attempt in range(1, max_attempts + 1):
            try:
                s3_client.head_object(Bucket=bucket_name, Key=key)
                raise FileExistsError(
                    f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
                )
            except boto_exceptions.ClientError as e:
                if int(e.response["Error"]["Code"]) == 404:
                    # 404 means the object does not exist, so it is safe to upload.
                    err = None
                    break
                # Any other client error is treated as retriable.
                err = e

            if attempt < max_attempts:
                log.warning("%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err)
                _wait_before_retry(attempt)

        if err is not None:
            raise OlmoNetworkError("Failed to check object existence during s3 upload") from err

    try:
        s3_client.upload_file(source, bucket_name, key)
    except boto_exceptions.ClientError as e:
        raise OlmoNetworkError("Failed to upload to s3") from e

def _s3_file_size(bucket_name: str, key: str, max_attempts: int = 3) -> int:
    """Return the size in bytes of the object at ``s3://bucket_name/key``.

    :param bucket_name: S3 bucket holding the object.
    :param key: object key.
    :param max_attempts: number of HEAD attempts before giving up.
    :raises FileNotFoundError: the object does not exist (S3 returned 404).
    :raises OlmoNetworkError: the HEAD request kept failing after ``max_attempts`` tries.
    """
    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return s3_client.head_object(Bucket=bucket_name, Key=key)["ContentLength"]
        except boto_exceptions.ClientError as e:
            if int(e.response["Error"]["Code"]) == 404:
                # A missing object is a caller error, not a network failure — don't retry.
                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
            err = e

        if attempt < max_attempts:
            log.warning("%s failed attempt %d with retriable error: %s", _s3_file_size.__name__, attempt, err)
            _wait_before_retry(attempt)

    raise OlmoNetworkError("Failed to get s3 file size") from err

def _s3_get_bytes_range(
    bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
) -> bytes:
    """Read ``num_bytes`` bytes starting at offset ``bytes_start`` from ``s3://bucket_name/key``.

    :param bucket_name: S3 bucket holding the object.
    :param key: object key.
    :param bytes_start: zero-based offset of the first byte to read.
    :param num_bytes: number of bytes to read.
    :param max_attempts: number of GET attempts before giving up.
    :raises FileNotFoundError: the object does not exist (S3 returned 404).
    :raises OlmoNetworkError: the ranged GET kept failing after ``max_attempts`` tries.
    """
    err: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return s3_client.get_object(
                Bucket=bucket_name, Key=key, Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}"
            )["Body"].read()
        except boto_exceptions.ClientError as e:
            if int(e.response["Error"]["Code"]) == 404:
                # A missing object is a caller error, not a network failure — don't retry.
                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
            err = e
        except (boto_exceptions.HTTPClientError, boto_exceptions.ConnectionError) as e:
            # ResponseStreamingError (subclass of HTTPClientError) can happen as
            # a result of a failed read from the stream (http.client.IncompleteRead).
            # Retrying can help in this case.
            err = e

        if attempt < max_attempts:
            log.warning(
                "%s failed attempt %d with retriable error: %s", _s3_get_bytes_range.__name__, attempt, err
            )
            _wait_before_retry(attempt)

    # When torch's DataLoader intercepts exceptions, it may try to re-raise them
    # by recalling their constructor with a single message arg. Torch has some
    # logic to deal with the absence of a single-parameter constructor, but it
    # doesn't gracefully handle other possible failures in calling such a constructor
    # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
    # in us losing the true exception info. To avoid this, we change the exception
    # to a type that has a single-parameter constructor.
    raise OlmoNetworkError("Failed to get bytes range from s3") from err


def is_weight_decay_module(module: nn.Module) -> bool:
Expand Down

0 comments on commit 0a1455b

Please sign in to comment.