Skip to content

Commit

Permalink
Fix partitioned state saving issue (#37389)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxi297 authored Apr 18, 2024
1 parent a49f205 commit d2c8e63
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def on_partition_complete_sentinel(self, sentinel: PartitionCompleteSentinel) ->
partition = sentinel.partition

try:
partition.close()
if sentinel.is_successful:
partition.close()
except Exception as exception:
self._flag_exception(partition.stream_name(), exception)
yield AirbyteTracedException.from_exception(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class PartitionReader:
Generates records from a partition and puts them in a queue.
"""

_IS_SUCCESSFUL = True

def __init__(self, queue: Queue[QueueItem]) -> None:
"""
:param queue: The queue to put the records in.
Expand All @@ -34,7 +36,7 @@ def process_partition(self, partition: Partition) -> None:
try:
for record in partition.read():
self._queue.put(record)
self._queue.put(PartitionCompleteSentinel(partition))
self._queue.put(PartitionCompleteSentinel(partition, self._IS_SUCCESSFUL))
except Exception as e:
self._queue.put(StreamThreadException(e, partition.stream_name()))
self._queue.put(PartitionCompleteSentinel(partition))
self._queue.put(PartitionCompleteSentinel(partition, not self._IS_SUCCESSFUL))
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ class PartitionCompleteSentinel:
Includes a pointer to the partition that was processed.
"""

def __init__(self, partition: Partition):
def __init__(self, partition: Partition, is_successful: bool = True):
"""
:param partition: The partition that was processed
"""
self.partition = partition
self.is_successful = is_successful

def __eq__(self, other: Any) -> bool:
if isinstance(other, PartitionCompleteSentinel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_STREAM_NAME = "stream"
_ANOTHER_STREAM_NAME = "stream2"
_ANY_AIRBYTE_MESSAGE = Mock(spec=AirbyteMessage)
_IS_SUCCESSFUL = True


class TestConcurrentReadProcessor(unittest.TestCase):
Expand Down Expand Up @@ -560,6 +561,27 @@ def test_on_exception_return_trace_message_and_on_stream_complete_return_stream_
)
]

def test_given_partition_completion_is_not_success_then_do_not_close_partition(self):
stream_instances_to_read_from = [self._stream, self._another_stream]

handler = ConcurrentReadProcessor(
stream_instances_to_read_from,
self._partition_enqueuer,
self._thread_pool_manager,
self._logger,
self._slice_logger,
self._message_repository,
self._partition_reader,
)

handler.start_next_partition_generator()
handler.on_partition(self._an_open_partition)
list(handler.on_partition_generation_completed(PartitionGenerationCompletedSentinel(self._stream)))

list(handler.on_partition_complete_sentinel(PartitionCompleteSentinel(self._an_open_partition, not _IS_SUCCESSFUL)))

assert self._an_open_partition.close.call_count == 0

def test_is_done_is_false_if_there_are_any_instances_to_read_from(self):
stream_instances_to_read_from = [self._stream]

Expand Down

0 comments on commit d2c8e63

Please sign in to comment.