diff --git a/deeplake/core/storage/s3.py b/deeplake/core/storage/s3.py index 16229add00..abeb704186 100644 --- a/deeplake/core/storage/s3.py +++ b/deeplake/core/storage/s3.py @@ -5,6 +5,7 @@ import boto3 import botocore # type: ignore import posixpath +import ssl from typing import Dict, Optional, Tuple, Type from datetime import datetime, timezone from botocore.session import ComponentLocator @@ -41,6 +42,7 @@ EndpointConnectionError, IncompleteReadError, SSLError, + ssl.SSLError, ) try: @@ -190,8 +192,10 @@ def __setitem__(self, path, content): self._set(path, content) except CONNECTION_ERRORS as err: tries = self.num_tries + retry_wait = 0 for i in range(1, tries + 1): always_warn(f"Encountered connection error, retry {i} out of {tries}") + retry_wait = self._retry_wait_and_extend(retry_wait, err) try: self._set(path, content) always_warn( @@ -280,8 +284,11 @@ def get_bytes( return self._get_bytes(path, start_byte, end_byte) except CONNECTION_ERRORS as err: tries = self.num_tries + retry_wait = 0 for i in range(1, tries + 1): always_warn(f"Encountered connection error, retry {i} out of {tries}") + retry_wait = self._retry_wait_and_extend(retry_wait, err) + try: ret = self._get_bytes(path, start_byte, end_byte) always_warn( @@ -322,8 +329,11 @@ def __delitem__(self, path): self._del(path) except CONNECTION_ERRORS as err: tries = self.num_tries + retry_wait = 0 for i in range(1, tries + 1): always_warn(f"Encountered connection error, retry {i} out of {tries}") + retry_wait = self._retry_wait_and_extend(retry_wait, err) + try: self._del(path) always_warn( @@ -338,7 +348,7 @@ def __delitem__(self, path): @property def num_tries(self): - return min(ceil((time.time() - self.start_time) / 300), 5) + return max(3, min(ceil((time.time() - self.start_time) / 300), 5)) def _keys_iterator(self): self._check_update_creds() @@ -625,11 +635,45 @@ def get_presigned_url(self, key, full=False): self._presigned_urls[path] = (url, time.time()) return url + def _get_object_size(self, path: str) -> int: + obj = self.resource.Object(self.bucket, path) + return obj.content_length + def get_object_size(self, path: str) -> int: self._check_update_creds() path = "".join((self.path, path)) - obj = self.resource.Object(self.bucket, path) - return obj.content_length + + try: + return self._get_object_size(path) + except botocore.exceptions.ClientError as err: + if err.response["Error"]["Code"] == "NoSuchKey": + raise KeyError(err) from err + if err.response["Error"]["Code"] == "InvalidAccessKeyId": + new_error_cls: Type[S3GetError] = S3GetAccessError + else: + new_error_cls = S3GetError + with S3ResetReloadCredentialsManager(self, new_error_cls): + return self._get_object_size(path) + except CONNECTION_ERRORS as err: + tries = self.num_tries + retry_wait = 0 + for i in range(1, tries + 1): + always_warn(f"Encountered connection error, retry {i} out of {tries}") + retry_wait = self._retry_wait_and_extend(retry_wait, err) + + try: + ret = self._get_object_size(path) + always_warn( + f"Connection re-established after {i} {['retries', 'retry'][i==1]}." + ) + return ret + except Exception: + pass + raise S3GetError(err) from err + except botocore.exceptions.NoCredentialsError as err: + raise S3GetAccessError from err + except Exception as err: + raise S3GetError(err) from err def get_object_from_full_url(self, url: str): root = url.replace("s3://", "") @@ -645,8 +689,11 @@ def get_object_from_full_url(self, url: str): return self._get(path, bucket) except CONNECTION_ERRORS as err: tries = self.num_tries + retry_wait = 0 for i in range(1, tries + 1): always_warn(f"Encountered connection error, retry {i} out of {tries}") + retry_wait = self._retry_wait_and_extend(retry_wait, err) + try: ret = self._get(path, bucket) always_warn( @@ -685,8 +732,11 @@ def set_items(self, items: dict): self._set_items(items) except CONNECTION_ERRORS as err: tries = self.num_tries + retry_wait = 0 for i in range(1, tries + 1): always_warn(f"Encountered connection error, retry {i} out of {tries}") + retry_wait = self._retry_wait_and_extend(retry_wait, err) + try: self._set_items(items) always_warn( @@ -713,3 +763,13 @@ def get_items(self, keys): yield key, future.result() else: yield key, exception + + def _retry_wait_and_extend(self, retry_wait: int, err: Exception): + if not (isinstance(err, ssl.SSLError) or isinstance(err, SSLError)): + return 0 + + time.sleep(retry_wait) + + if retry_wait == 0: + return 1 + return retry_wait * 2 diff --git a/deeplake/core/storage/tests/test_storage_provider.py b/deeplake/core/storage/tests/test_storage_provider.py index b14c325cf1..7144f941d8 100644 --- a/deeplake/core/storage/tests/test_storage_provider.py +++ b/deeplake/core/storage/tests/test_storage_provider.py @@ -1,4 +1,9 @@ import json +import ssl +import time +from unittest.mock import patch + +from deeplake.core import S3Provider from deeplake.tests.path_fixtures import gcs_creds from deeplake.tests.common import is_opt_true from deeplake.tests.storage_fixtures import ( @@ -229,3 +234,22 @@ def test_azure_empty_blob(azure_storage): azure_storage.get_object_from_full_url(f"{azure_storage.root}/empty_blob") == b"" ) + + +@pytest.mark.slow +def test_s3_backoff(): + runs = 0 + s3 = S3Provider("s3://mock") + + def fake_set_items(items: dict): + nonlocal runs + runs = runs + 1 + + raise ssl.SSLError + + start = time.time() + with patch("deeplake.core.storage.s3.S3Provider._set_items", wraps=fake_set_items): + with pytest.raises(Exception): + s3.set_items({"test": "test"}) + assert runs == 4 + assert 3 < time.time() - start < 5