fix: handle error with upload retry on unbound source
yedpodtrzitko committed Jun 14, 2023
1 parent 5753d56 commit f7c89b4
Showing 9 changed files with 154 additions and 61 deletions.
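The gist of the change, as read from the diff below (a reconstruction, not the author's own summary): an unbound source's buffer is now released when its stream is closed rather than on the first empty read(), and the part-upload retry loop keeps a single stream open across attempts, closing it exactly once through a contextlib.ExitStack after retrying has concluded.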
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 * Circular symlinks no longer cause infinite loops when syncing a folder
+* Fix crash on upload retry with unbound data source
 
 ### Infrastructure
 * Replaced `pyflakes` with `ruff` for linting
4 changes: 3 additions & 1 deletion b2sdk/raw_simulator.py
@@ -1069,7 +1069,6 @@ def upload_part(
         input_stream,
         server_side_encryption: Optional[EncryptionSetting] = None,
     ):
-        file_sim = self.file_id_to_file[file_id]
         part_data = self._simulate_chunked_post(input_stream, content_length)
         assert len(part_data) == content_length
         if sha1_sum == HEX_DIGITS_AT_END:
@@ -1079,8 +1078,11 @@ def upload_part(
         computed_sha1 = hex_sha1_of_bytes(part_data)
         if sha1_sum != computed_sha1:
             raise PartSha1Mismatch(file_id)
+
+        file_sim = self.file_id_to_file[file_id]
         part = PartSimulator(file_sim.file_id, part_number, content_length, sha1_sum, part_data)
         file_sim.add_part(part_number, part)
+
         result = dict(
             fileId=file_id,
             partNumber=part_number,
12 changes: 8 additions & 4 deletions b2sdk/transfer/emerge/executor.py
@@ -12,7 +12,7 @@
 import threading
 
 from abc import ABCMeta, abstractmethod
-from typing import Dict, Optional
+from typing import Dict, Optional, TYPE_CHECKING
 
 from b2sdk.encryption.setting import EncryptionSetting
 from b2sdk.exception import MaxFileSizeExceeded
@@ -25,6 +25,10 @@
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from b2sdk.transfer.emerge.planner.part_definition import UploadEmergePartDefinition
+    from b2sdk.transfer.emerge.planner.planner import StreamingEmergePlan
+
 
 class EmergeExecutor:
     def __init__(self, services):
@@ -164,7 +168,7 @@ def __init__(
         if self.max_queue_size is not None:
             self._semaphore = threading.Semaphore(self.max_queue_size)
 
-    def execute_plan(self, emerge_plan):
+    def execute_plan(self, emerge_plan: "StreamingEmergePlan"):
         total_length = emerge_plan.get_total_length()
         encryption = self.encryption
 
@@ -240,7 +244,7 @@ def execute_plan(self, emerge_plan):
         response = self.services.session.finish_large_file(file_id, part_sha1_array)
         return self.services.api.file_version_factory.from_api_response(response)
 
-    def _execute_step(self, execution_step):
+    def _execute_step(self, execution_step: "UploadPartExecutionStep"):
         semaphore = self._semaphore
         if semaphore is None:
             return execution_step.execute()
@@ -541,7 +545,7 @@ class LargeFileEmergeExecutionStepFactory(BaseExecutionStepFactory):
     def __init__(
         self,
        emerge_execution,
-        emerge_part,
+        emerge_part: "UploadEmergePartDefinition",
        part_number,
        large_file_id,
        large_file_upload_state,
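A quick aside on the pattern these annotations use (my sketch, not part of the commit): imports needed only for type annotations go under typing.TYPE_CHECKING, so they are skipped at runtime and cannot create import cycles, while the annotations themselves are quoted strings that Python never has to resolve at definition time:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by static type checkers only, never at runtime
    from b2sdk.transfer.emerge.planner.planner import StreamingEmergePlan

def execute_plan(emerge_plan: "StreamingEmergePlan") -> None:
    ...  # stand-in body; the real method lives on the executor classes above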
6 changes: 5 additions & 1 deletion b2sdk/transfer/emerge/planner/part_definition.py
@@ -13,9 +13,13 @@
 
 from b2sdk.stream.chained import ChainedStream
 from b2sdk.stream.range import wrap_with_range
+from typing import TYPE_CHECKING
 
 from b2sdk.utils import hex_sha1_of_unlimited_stream
 
+if TYPE_CHECKING:
+    from b2sdk.transfer.emerge.unbound_write_intent import UnboundSourceBytes
+
 
 class BaseEmergePartDefinition(metaclass=ABCMeta):
     @abstractmethod
@@ -38,7 +42,7 @@ def get_sha1(self):
 
 
 class UploadEmergePartDefinition(BaseEmergePartDefinition):
-    def __init__(self, upload_source, relative_offset, length):
+    def __init__(self, upload_source: "UnboundSourceBytes", relative_offset, length):
         self.upload_source = upload_source
         self.relative_offset = relative_offset
         self.length = length
21 changes: 7 additions & 14 deletions b2sdk/transfer/emerge/unbound_write_intent.py
@@ -11,7 +11,7 @@
 import hashlib
 import io
 import queue
-from typing import Callable, Iterator, Optional, Union
+from typing import Callable, Iterator, Union
 
 from b2sdk.transfer.emerge.exception import UnboundStreamBufferTimeout
 from b2sdk.transfer.emerge.write_intent import WriteIntent
@@ -41,26 +41,19 @@ def __init__(
         a ``release_function`` when buffer is read in full.
         ``release_function`` can be called from another thread.
-        It is called exactly once, when the read returns
-        an empty buffer for the first time.
+        It is called exactly once, when the read is concluded
+        and the resource is about to be released
         :param data: data to be provided as a stream
-        :param release_function: function to be called when all the data was read
+        :param release_function: function to be called when resource will be released
         """
         super().__init__(data)
 
-        self.already_done = False
         self.release_function = release_function
 
-    def read(self, size: Optional[int] = None) -> bytes:
-        result = super().read(size)
-
-        is_done = len(result) == 0
-        if is_done and not self.already_done:
-            self.already_done = True
+    def close(self):
+        if not self.closed:
             self.release_function()
-
-        return result
+        return super().close()
 
 
 class UnboundSourceBytes(AbstractUploadSource):
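To make the new contract concrete, a minimal usage sketch (my illustration, assuming IOWrapper is importable from this module and wraps an in-memory bytes buffer, as the tests further down suggest): the release now fires on close(), exactly once, no matter how much was read:

from unittest.mock import MagicMock

from b2sdk.transfer.emerge.unbound_write_intent import IOWrapper

release = MagicMock()
wrapper = IOWrapper(b"hello", release_function=release)

wrapper.read()                # draining the buffer no longer triggers the release
release.assert_not_called()

wrapper.close()               # the release fires here instead
wrapper.close()               # a second close is a no-op thanks to the `closed` guard
release.assert_called_once()  # called exactly once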
48 changes: 31 additions & 17 deletions b2sdk/transfer/outbound/upload_manager.py
@@ -9,8 +9,9 @@
 ######################################################################
 
 import logging
+from contextlib import ExitStack
 
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
 
 from b2sdk.encryption.setting import EncryptionMode, EncryptionSetting
 from b2sdk.exception import (
@@ -29,6 +30,9 @@
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from ...utils.typing import TypeUploadSource
+
 
 class UploadManager(TransferManager, ThreadPoolMixin):
     """
@@ -75,7 +79,7 @@ def upload_part(
         self,
         bucket_id,
         file_id,
-        part_upload_source,
+        part_upload_source: "TypeUploadSource",
         part_number,
         large_file_upload_state,
         finished_parts=None,
@@ -97,7 +101,7 @@ def _upload_part(
         self,
         bucket_id,
         file_id,
-        part_upload_source,
+        part_upload_source: "TypeUploadSource",
         part_number,
         large_file_upload_state,
         finished_parts,
@@ -134,14 +138,25 @@ def _upload_part(
 
         # Retry the upload as needed
         exception_list = []
-        for _ in range(self.MAX_UPLOAD_ATTEMPTS):
-            # if another part has already had an error there's no point in
-            # uploading this part
-            if large_file_upload_state.has_error():
-                raise AlreadyFailed(large_file_upload_state.get_error_message())
-
-            try:
-                with part_upload_source.open() as part_stream:
+        with ExitStack() as stream_guard:
+            part_stream = None
+
+            def close_stream_callback(stream):
+                if not stream.closed:
+                    stream.close()
+
+            for _ in range(self.MAX_UPLOAD_ATTEMPTS):
+                # if another part has already had an error there's no point in
+                # uploading this part
+                if large_file_upload_state.has_error():
+                    raise AlreadyFailed(large_file_upload_state.get_error_message())
+
+                try:
+                    # reuse the stream in case of retry
+                    part_stream = part_stream or part_upload_source.open()
+                    # register stream closing callback only when reading is finally concluded
+                    stream_guard.callback(close_stream_callback, part_stream)
+
                     content_length = part_upload_source.get_content_length()
                     input_stream = ReadingStreamWithProgress(
                         part_stream, part_progress_listener, length=content_length
@@ -164,12 +179,11 @@ def _upload_part(
                     content_sha1 = input_stream.hash
                     assert content_sha1 == response['contentSha1']
                     return response
-
-            except B2Error as e:
-                if not e.should_retry_upload():
-                    raise
-                exception_list.append(e)
-                self.account_info.clear_bucket_upload_data(bucket_id)
+                except B2Error as e:
+                    if not e.should_retry_upload():
+                        raise
+                    exception_list.append(e)
+                    self.account_info.clear_bucket_upload_data(bucket_id)
 
         large_file_upload_state.set_error(str(exception_list[-1]))
         raise MaxRetriesExceeded(self.MAX_UPLOAD_ATTEMPTS, exception_list)
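The heart of the fix is this ExitStack arrangement: the part stream is opened once, reused on every retry attempt, and closed exactly once when the stack unwinds, instead of being closed (and, for unbound sources, its buffer released) at the end of the first failed attempt. A condensed, self-contained sketch of the same pattern — source, upload_once and TransientError are hypothetical stand-ins for the SDK's upload source, session call and retryable B2Error:

from contextlib import ExitStack


class TransientError(Exception):
    pass  # stand-in for a retryable error (the SDK's B2Error family)


def upload_with_retries(source, upload_once, attempts=5):
    errors = []
    with ExitStack() as stream_guard:
        stream = None

        def close_stream_callback(s):
            # guarded close: re-registration on retries cannot close twice
            if not s.closed:
                s.close()

        for _ in range(attempts):
            try:
                # reuse the already-open stream on retry instead of reopening it
                stream = stream or source.open()
                stream_guard.callback(close_stream_callback, stream)
                return upload_once(stream)
            except TransientError as e:
                errors.append(e)
    raise RuntimeError(f"gave up after {attempts} attempts: {errors}")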
15 changes: 15 additions & 0 deletions b2sdk/utils/typing.py
@@ -0,0 +1,15 @@
+######################################################################
+#
+# File: b2sdk/utils/typing.py
+#
+# Copyright 2023 Backblaze Inc. All Rights Reserved.
+#
+# License https://www.backblaze.com/using_b2_code.html
+#
+######################################################################
+
+from typing import TypeVar
+
+from b2sdk.transfer.outbound.upload_source import AbstractUploadSource
+
+TypeUploadSource = TypeVar("TypeUploadSource", bound=AbstractUploadSource)
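For orientation (my illustration, not part of the commit): a TypeVar bound like this one lets helper signatures accept any AbstractUploadSource subclass while preserving the concrete type for the caller. A hypothetical example, relying only on get_content_length(), which the upload sources expose in the diff above:

from typing import List

from b2sdk.utils.typing import TypeUploadSource


def largest_source(sources: List[TypeUploadSource]) -> TypeUploadSource:
    # hypothetical helper: whatever concrete source type goes in comes back out
    return max(sources, key=lambda s: s.get_content_length())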
76 changes: 76 additions & 0 deletions test/unit/bucket/test_bucket.py
@@ -23,6 +23,7 @@
 import apiver_deps
 from apiver_deps_exception import (
     AlreadyFailed,
+    B2ConnectionError,
     B2Error,
     B2RequestTimeoutDuringUpload,
     BucketIdNotFound,
@@ -50,6 +51,7 @@
 from apiver_deps import FileVersion as VFileVersionInfo
 from apiver_deps import B2Api
 from apiver_deps import B2HttpApiConfig
+from apiver_deps import B2Session
 from apiver_deps import Bucket, BucketFactory
 from apiver_deps import DownloadedFile
 from apiver_deps import DownloadVersion
@@ -1729,6 +1731,80 @@ def _upload_part(self, large_file_id, part_number, part_data):
     )
 
 
+class TestBucketRaisingSession(TestUpload):
+    def get_api(self):
+        class B2SessionRaising(B2Session):
+            def __init__(self, *args, **kwargs):
+                self._raise_count = 0
+                self._raise_until = 1
+                super().__init__(*args, **kwargs)
+
+            def upload_part(
+                self,
+                file_id,
+                part_number,
+                content_length,
+                sha1_sum,
+                input_stream,
+                server_side_encryption=None
+            ):
+                if self._raise_count < self._raise_until:
+                    self._raise_count += 1
+                    raise B2ConnectionError()
+                return super().upload_part(
+                    file_id, part_number, content_length, sha1_sum, input_stream,
+                    server_side_encryption
+                )
+
+        class B2ApiPatched(B2Api):
+            SESSION_CLASS = staticmethod(B2SessionRaising)
+
+        self.api = B2ApiPatched(
+            self.account_info,
+            cache=self.CACHE_CLASS(),
+            api_config=B2HttpApiConfig(_raw_api_class=self.RAW_SIMULATOR_CLASS),
+        )
+        return self.api
+
+    def test_upload_chunk_retry_stream_open(self):
+        assert self.api.session._raise_count == 0
+        data = self._make_data(self.simulator.MIN_PART_SIZE * 3)
+        self.bucket.upload_unbound_stream(io.BytesIO(data), 'file1')
+        self._check_file_contents('file1', data)
+        assert self.api.session._raise_count == 1
+
+    def test_upload_chunk_stream_guard_closes(self):
+        data = self._make_data(self.simulator.MIN_PART_SIZE * 3)
+        large_file_upload_state = mock.MagicMock()
+        large_file_upload_state.has_error.return_value = False
+
+        class TrackedUploadSourceBytes(UploadSourceBytes):
+            def __init__(self, *args, **kwargs):
+                self._close_called = 0
+                super().__init__(*args, **kwargs)
+
+            def open(self):
+                class TrackedBytesIO(io.BytesIO):
+                    def __init__(self, parent, *args, **kwargs):
+                        self._parent = parent
+                        super().__init__(*args, **kwargs)
+
+                    def close(self):
+                        self._parent._close_called += 1
+                        return super().close()
+
+                return TrackedBytesIO(self, self.data_bytes)
+
+        data_source = TrackedUploadSourceBytes(data)
+        assert data_source._close_called == 0
+        file_id = self._start_large_file('file1')
+        self.api.services.upload_manager.upload_part(
+            self.bucket_id, file_id, data_source, 1, large_file_upload_state
+        ).result()
+        # one retry means two potential callback calls, but we want one only
+        assert data_source._close_called == 1
+
+
 class TestConcatenate(TestCaseWithBucket):
     def _create_remote(self, sources, file_name, encryption=None):
         return self.bucket.concatenate(sources, file_name=file_name, encryption=encryption)
32 changes: 8 additions & 24 deletions test/unit/internal/test_unbound_write_intent.py
@@ -31,39 +31,22 @@ def setUp(self) -> None:
         self.mock_fun = MagicMock()
         self.wrapper = IOWrapper(self.data, release_function=self.mock_fun)
 
-    def test_function_called_only_after_empty_read(self):
-        self.mock_fun.assert_not_called()
-
-        self.wrapper.read(1)
-        self.mock_fun.assert_not_called()
-
-        self.wrapper.read(len(self.data) - 1)
-        self.mock_fun.assert_not_called()
-
-        self.wrapper.seek(0)
+    def test_function_called_on_close_manual(self):
         self.mock_fun.assert_not_called()
 
         self.wrapper.read(len(self.data))
         self.mock_fun.assert_not_called()
 
-        self.wrapper.seek(0)
-        self.wrapper.read(len(self.data))
-        self.mock_fun.assert_not_called()
-
-        for _ in range(len(self.data)):
-            self.wrapper.read(1)
-            self.mock_fun.assert_not_called()
-
-        self.assertEqual(0, len(self.wrapper.read(1)))
-        self.mock_fun.assert_called_once()
 
-    def test_function_called_exactly_once(self):
-        self.wrapper.read(len(self.data))
-        self.wrapper.read(1)
+        self.wrapper.close()
         self.mock_fun.assert_called_once()
 
-        self.wrapper.seek(0)
-        self.wrapper.read(len(self.data))
-        self.wrapper.read(1)
+    def test_function_called_on_close_context(self):
+        self.mock_fun.assert_not_called()
+        with self.wrapper as w:
+            w.read(len(self.data))
         self.mock_fun.assert_called_once()
 
 
@@ -104,6 +87,7 @@ def _read_write_intent(self, write_intent: WriteIntent, full_read_size: int = 1)
         read_data = buffer_stream.read(full_read_size)
         empty_data = buffer_stream.read(full_read_size)
         self.assertEqual(0, len(empty_data))
+        buffer_stream.close()
         return read_data
 
     def test_timeout_called_when_waiting_too_long_for_empty_buffer_slot(self):
