diff --git a/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py index 24ac315c526e5..f28170490e454 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_read_processor.py @@ -7,6 +7,7 @@ from airbyte_cdk.models import AirbyteMessage, AirbyteStreamStatus from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.message import MessageRepository from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream @@ -17,7 +18,9 @@ from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel from airbyte_cdk.sources.utils.record_helper import stream_data_to_airbyte_message from airbyte_cdk.sources.utils.slice_logger import SliceLogger +from airbyte_cdk.utils import AirbyteTracedException from airbyte_cdk.utils.stream_status_utils import as_airbyte_message as stream_status_as_airbyte_message +from airbyte_protocol.models import StreamDescriptor class ConcurrentReadProcessor: @@ -56,6 +59,7 @@ def __init__( self._message_repository = message_repository self._partition_reader = partition_reader self._streams_done: Set[str] = set() + self._exceptions_per_stream_name: dict[str, List[Exception]] = {} def on_partition_generation_completed(self, sentinel: PartitionGenerationCompletedSentinel) -> Iterable[AirbyteMessage]: """ @@ -126,14 +130,16 @@ def on_record(self, record: Record) -> Iterable[AirbyteMessage]: yield message yield from self._message_repository.consume_queue() - def on_exception(self, exception: Exception) -> Iterable[AirbyteMessage]: + def on_exception(self, exception: StreamThreadException) -> Iterable[AirbyteMessage]: """ This method is called when an exception is raised. 1. Stop all running streams 2. Raise the exception """ - yield from self._stop_streams() - raise exception + self._exceptions_per_stream_name.setdefault(exception.stream_name, []).append(exception.exception) + yield AirbyteTracedException.from_exception(exception).as_airbyte_message( + stream_descriptor=StreamDescriptor(name=exception.stream_name) + ) def start_next_partition_generator(self) -> Optional[AirbyteMessage]: """ @@ -177,13 +183,7 @@ def _on_stream_is_done(self, stream_name: str) -> Iterable[AirbyteMessage]: yield from self._message_repository.consume_queue() self._logger.info(f"Finished syncing {stream.name}") self._streams_done.add(stream_name) - yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), AirbyteStreamStatus.COMPLETE) - - def _stop_streams(self) -> Iterable[AirbyteMessage]: - self._thread_pool_manager.shutdown() - for stream_name in self._streams_to_running_partitions.keys(): - stream = self._stream_name_to_instance[stream_name] - if not self._is_stream_done(stream_name): - self._logger.info(f"Marking stream {stream.name} as STOPPED") - self._logger.info(f"Finished syncing {stream.name}") - yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), AirbyteStreamStatus.INCOMPLETE) + stream_status = ( + AirbyteStreamStatus.INCOMPLETE if self._exceptions_per_stream_name.get(stream_name, []) else AirbyteStreamStatus.COMPLETE + ) + yield stream_status_as_airbyte_message(stream.as_airbyte_stream(), stream_status) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_source.py b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_source.py index f7d65d31aca70..714a6104d00a3 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_source.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/concurrent_source.py @@ -9,6 +9,7 @@ from airbyte_cdk.models import AirbyteMessage from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream @@ -123,11 +124,6 @@ def _consume_from_queue( concurrent_stream_processor: ConcurrentReadProcessor, ) -> Iterable[AirbyteMessage]: while airbyte_message_or_record_or_exception := queue.get(): - try: - self._threadpool.shutdown_if_exception() - except Exception as exception: - concurrent_stream_processor.on_exception(exception) - yield from self._handle_item( airbyte_message_or_record_or_exception, concurrent_stream_processor, @@ -142,7 +138,7 @@ def _handle_item( concurrent_stream_processor: ConcurrentReadProcessor, ) -> Iterable[AirbyteMessage]: # handle queue item and call the appropriate handler depending on the type of the queue item - if isinstance(queue_item, Exception): + if isinstance(queue_item, StreamThreadException): yield from concurrent_stream_processor.on_exception(queue_item) elif isinstance(queue_item, PartitionGenerationCompletedSentinel): yield from concurrent_stream_processor.on_partition_generation_completed(queue_item) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py new file mode 100644 index 0000000000000..c865bef597326 --- /dev/null +++ b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/stream_thread_exception.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. + +from typing import Any + + +class StreamThreadException(Exception): + def __init__(self, exception: Exception, stream_name: str): + self._exception = exception + self._stream_name = stream_name + + @property + def stream_name(self) -> str: + return self._stream_name + + @property + def exception(self) -> Exception: + return self._exception + + def __str__(self) -> str: + return f"Exception while syncing stream {self._stream_name}: {self._exception}" + + def __eq__(self, other: Any) -> bool: + if isinstance(other, StreamThreadException): + return self._exception == other._exception and self._stream_name == other._stream_name + return False diff --git a/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py index 560989af0a6cd..b6933e6bc3d2a 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/concurrent_source/thread_pool_manager.py @@ -71,26 +71,26 @@ def _prune_futures(self, futures: List[Future[Any]]) -> None: ) futures.pop(index) - def shutdown(self) -> None: + def _shutdown(self) -> None: + # Without a way to stop the threads that have already started, this will not stop the Python application. We are fine today with + # this imperfect approach because we only do this in case of `self._most_recently_seen_exception` which we don't expect to happen self._threadpool.shutdown(wait=False, cancel_futures=True) def is_done(self) -> bool: return all([f.done() for f in self._futures]) - def shutdown_if_exception(self) -> None: - """ - This method will raise if there is an exception so that the caller can use it. - """ - if self._most_recently_seen_exception: - self._stop_and_raise_exception(self._most_recently_seen_exception) - def check_for_errors_and_shutdown(self) -> None: """ Check if any of the futures have an exception, and raise it if so. If all futures are done, shutdown the threadpool. If the futures are not done, raise an exception. :return: """ - self.shutdown_if_exception() + if self._most_recently_seen_exception: + self._logger.exception( + "An unknown exception has occurred while reading concurrently", + exc_info=self._most_recently_seen_exception, + ) + self._stop_and_raise_exception(self._most_recently_seen_exception) exceptions_from_futures = [f for f in [future.exception() for future in self._futures] if f is not None] if exceptions_from_futures: @@ -102,8 +102,8 @@ def check_for_errors_and_shutdown(self) -> None: exception = RuntimeError(f"Failed reading with futures not done: {futures_not_done}") self._stop_and_raise_exception(exception) else: - self.shutdown() + self._shutdown() def _stop_and_raise_exception(self, exception: BaseException) -> None: - self.shutdown() + self._shutdown() raise exception diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/cursor.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/cursor.py index 2a7dcd65c889f..08dee0716c529 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/cursor.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/cursor.py @@ -3,8 +3,7 @@ # import functools from abc import ABC, abstractmethod -from datetime import datetime -from typing import Any, List, Mapping, MutableMapping, Optional, Protocol, Tuple +from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Protocol, Tuple from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.message import MessageRepository @@ -18,11 +17,33 @@ def _extract_value(mapping: Mapping[str, Any], path: List[str]) -> Any: return functools.reduce(lambda a, b: a[b], path, mapping) -class Comparable(Protocol): +class GapType(Protocol): + """ + This is the representation of gaps between two cursor values. Examples: + * if cursor values are datetimes, GapType is timedelta + * if cursor values are integer, GapType will also be integer + """ + + pass + + +class CursorValueType(Protocol): """Protocol for annotating comparable types.""" @abstractmethod - def __lt__(self: "Comparable", other: "Comparable") -> bool: + def __lt__(self: "CursorValueType", other: "CursorValueType") -> bool: + pass + + @abstractmethod + def __ge__(self: "CursorValueType", other: "CursorValueType") -> bool: + pass + + @abstractmethod + def __add__(self: "CursorValueType", other: GapType) -> "CursorValueType": + pass + + @abstractmethod + def __sub__(self: "CursorValueType", other: GapType) -> "CursorValueType": pass @@ -30,7 +51,7 @@ class CursorField: def __init__(self, cursor_field_key: str) -> None: self.cursor_field_key = cursor_field_key - def extract_value(self, record: Record) -> Comparable: + def extract_value(self, record: Record) -> CursorValueType: cursor_value = record.data.get(self.cursor_field_key) if cursor_value is None: raise ValueError(f"Could not find cursor field {self.cursor_field_key} in record") @@ -118,7 +139,10 @@ def __init__( connector_state_converter: AbstractStreamStateConverter, cursor_field: CursorField, slice_boundary_fields: Optional[Tuple[str, str]], - start: Optional[Any], + start: Optional[CursorValueType], + end_provider: Callable[[], CursorValueType], + lookback_window: Optional[GapType] = None, + slice_range: Optional[GapType] = None, ) -> None: self._stream_name = stream_name self._stream_namespace = stream_namespace @@ -129,15 +153,18 @@ def __init__( # To see some example where the slice boundaries might not be defined, check https://github.com/airbytehq/airbyte/blob/1ce84d6396e446e1ac2377362446e3fb94509461/airbyte-integrations/connectors/source-stripe/source_stripe/streams.py#L363-L379 self._slice_boundary_fields = slice_boundary_fields if slice_boundary_fields else tuple() self._start = start + self._end_provider = end_provider self._most_recent_record: Optional[Record] = None self._has_closed_at_least_one_slice = False self.start, self._concurrent_state = self._get_concurrent_state(stream_state) + self._lookback_window = lookback_window + self._slice_range = slice_range @property def state(self) -> MutableMapping[str, Any]: return self._concurrent_state - def _get_concurrent_state(self, state: MutableMapping[str, Any]) -> Tuple[datetime, MutableMapping[str, Any]]: + def _get_concurrent_state(self, state: MutableMapping[str, Any]) -> Tuple[CursorValueType, MutableMapping[str, Any]]: if self._connector_state_converter.is_state_message_compatible(state): return self._start or self._connector_state_converter.zero_value, self._connector_state_converter.deserialize(state) return self._connector_state_converter.convert_from_sequential_state(self._cursor_field, state, self._start) @@ -203,23 +230,20 @@ def _emit_state_message(self) -> None: self._connector_state_manager.update_state_for_stream( self._stream_name, self._stream_namespace, - self._connector_state_converter.convert_to_sequential_state(self._cursor_field, self.state), + self._connector_state_converter.convert_to_state_message(self._cursor_field, self.state), ) - # TODO: if we migrate stored state to the concurrent state format - # (aka stop calling self._connector_state_converter.convert_to_sequential_state`), we'll need to cast datetimes to string or - # int before emitting state state_message = self._connector_state_manager.create_state_message(self._stream_name, self._stream_namespace) self._message_repository.emit_message(state_message) def _merge_partitions(self) -> None: self.state["slices"] = self._connector_state_converter.merge_intervals(self.state["slices"]) - def _extract_from_slice(self, partition: Partition, key: str) -> Comparable: + def _extract_from_slice(self, partition: Partition, key: str) -> CursorValueType: try: _slice = partition.to_slice() if not _slice: raise KeyError(f"Could not find key `{key}` in empty slice") - return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a Comparable + return self._connector_state_converter.parse_value(_slice[key]) # type: ignore # we expect the devs to specify a key that would return a CursorValueType except KeyError as exception: raise KeyError(f"Partition is expected to have key `{key}` but could not be found") from exception @@ -229,3 +253,66 @@ def ensure_at_least_one_state_emitted(self) -> None: called. """ self._emit_state_message() + + def generate_slices(self) -> Iterable[Tuple[CursorValueType, CursorValueType]]: + """ + Generating slices based on a few parameters: + * lookback_window: Buffer to remove from END_KEY of the highest slice + * slice_range: Max difference between two slices. If the difference between two slices is greater, multiple slices will be created + * start: `_split_per_slice_range` will clip any value to `self._start which means that: + * if upper is less than self._start, no slices will be generated + * if lower is less than self._start, self._start will be used as the lower boundary (lookback_window will not be considered in that case) + + Note that the slices will overlap at their boundaries. We therefore expect to have at least the lower or the upper boundary to be + inclusive in the API that is queried. + """ + self._merge_partitions() + + if self._start is not None and self._is_start_before_first_slice(): + yield from self._split_per_slice_range(self._start, self.state["slices"][0][self._connector_state_converter.START_KEY]) + + if len(self.state["slices"]) == 1: + yield from self._split_per_slice_range( + self._calculate_lower_boundary_of_last_slice(self.state["slices"][0][self._connector_state_converter.END_KEY]), + self._end_provider(), + ) + elif len(self.state["slices"]) > 1: + for i in range(len(self.state["slices"]) - 1): + yield from self._split_per_slice_range( + self.state["slices"][i][self._connector_state_converter.END_KEY], + self.state["slices"][i + 1][self._connector_state_converter.START_KEY], + ) + yield from self._split_per_slice_range( + self._calculate_lower_boundary_of_last_slice(self.state["slices"][-1][self._connector_state_converter.END_KEY]), + self._end_provider(), + ) + else: + raise ValueError("Expected at least one slice") + + def _is_start_before_first_slice(self) -> bool: + return self._start is not None and self._start < self.state["slices"][0][self._connector_state_converter.START_KEY] + + def _calculate_lower_boundary_of_last_slice(self, lower_boundary: CursorValueType) -> CursorValueType: + if self._lookback_window: + return lower_boundary - self._lookback_window + return lower_boundary + + def _split_per_slice_range(self, lower: CursorValueType, upper: CursorValueType) -> Iterable[Tuple[CursorValueType, CursorValueType]]: + if lower >= upper: + return + + if self._start and upper < self._start: + return + + lower = max(lower, self._start) if self._start else lower + if not self._slice_range or lower + self._slice_range >= upper: + yield lower, upper + else: + stop_processing = False + current_lower_boundary = lower + while not stop_processing: + current_upper_boundary = min(current_lower_boundary + self._slice_range, upper) + yield current_lower_boundary, current_upper_boundary + current_lower_boundary = current_upper_boundary + if current_upper_boundary >= upper: + stop_processing = True diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py index 3869c6cf9e732..8e63c16a4b2c2 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_enqueuer.py @@ -5,6 +5,7 @@ from queue import Queue from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem @@ -52,4 +53,5 @@ def generate_partitions(self, stream: AbstractStream) -> None: self._queue.put(partition) self._queue.put(PartitionGenerationCompletedSentinel(stream)) except Exception as e: - self._queue.put(e) + self._queue.put(StreamThreadException(e, stream.name)) + self._queue.put(PartitionGenerationCompletedSentinel(stream)) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_reader.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_reader.py index 3df19ca29f926..c0cbf778b6576 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_reader.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partition_reader.py @@ -3,6 +3,7 @@ # from queue import Queue +from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.types import PartitionCompleteSentinel, QueueItem @@ -35,4 +36,5 @@ def process_partition(self, partition: Partition) -> None: self._queue.put(record) self._queue.put(PartitionCompleteSentinel(partition)) except Exception as e: - self._queue.put(e) + self._queue.put(StreamThreadException(e, partition.stream_name())) + self._queue.put(PartitionCompleteSentinel(partition)) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/types.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/types.py index fe16b2b0f9ab1..1ffdf6a903ef0 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/types.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/partitions/types.py @@ -2,7 +2,7 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -from typing import Union +from typing import Any, Union from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition @@ -21,6 +21,11 @@ def __init__(self, partition: Partition): """ self.partition = partition + def __eq__(self, other: Any) -> bool: + if isinstance(other, PartitionCompleteSentinel): + return self.partition == other.partition + return False + """ Typedef representing the items that can be added to the ThreadBasedConcurrentStream diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py index 843f477ddb160..e442dc6d97e96 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/state_converters/abstract_stream_state_converter.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, List, MutableMapping, Tuple +from typing import TYPE_CHECKING, Any, List, MutableMapping, Optional, Tuple if TYPE_CHECKING: from airbyte_cdk.sources.streams.concurrent.cursor import CursorField @@ -19,11 +19,65 @@ class AbstractStreamStateConverter(ABC): END_KEY = "end" @abstractmethod + def _from_state_message(self, value: Any) -> Any: + pass + + @abstractmethod + def _to_state_message(self, value: Any) -> Any: + pass + + def __init__(self, is_sequential_state: bool = True): + self._is_sequential_state = is_sequential_state + + def convert_to_state_message(self, cursor_field: "CursorField", stream_state: MutableMapping[str, Any]) -> MutableMapping[str, Any]: + """ + Convert the state message from the concurrency-compatible format to the stream's original format. + + e.g. + { "created": "2021-01-18T21:18:20.000Z" } + """ + if self.is_state_message_compatible(stream_state) and self._is_sequential_state: + legacy_state = stream_state.get("legacy", {}) + latest_complete_time = self._get_latest_complete_time(stream_state.get("slices", [])) + if latest_complete_time is not None: + legacy_state.update({cursor_field.cursor_field_key: self._to_state_message(latest_complete_time)}) + return legacy_state or {} + else: + return self.serialize(stream_state, ConcurrencyCompatibleStateType.date_range) + + def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> Any: + """ + Get the latest time before which all records have been processed. + """ + if not slices: + raise RuntimeError("Expected at least one slice but there were none. This is unexpected; please contact Support.") + + merged_intervals = self.merge_intervals(slices) + first_interval = merged_intervals[0] + return first_interval[self.END_KEY] + def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, Any]: """ Perform any transformations needed for compatibility with the converter. """ - ... + for stream_slice in state.get("slices", []): + stream_slice[self.START_KEY] = self._from_state_message(stream_slice[self.START_KEY]) + stream_slice[self.END_KEY] = self._from_state_message(stream_slice[self.END_KEY]) + return state + + def serialize(self, state: MutableMapping[str, Any], state_type: ConcurrencyCompatibleStateType) -> MutableMapping[str, Any]: + """ + Perform any transformations needed for compatibility with the converter. + """ + serialized_slices = [] + for stream_slice in state.get("slices", []): + serialized_slices.append( + { + self.START_KEY: self._to_state_message(stream_slice[self.START_KEY]), + self.END_KEY: self._to_state_message(stream_slice[self.END_KEY]), + } + ) + return {"slices": serialized_slices, "state_type": state_type.value} @staticmethod def is_state_message_compatible(state: MutableMapping[str, Any]) -> bool: @@ -32,9 +86,9 @@ def is_state_message_compatible(state: MutableMapping[str, Any]) -> bool: @abstractmethod def convert_from_sequential_state( self, - cursor_field: "CursorField", + cursor_field: "CursorField", # to deprecate as it is only needed for sequential state stream_state: MutableMapping[str, Any], - start: Any, + start: Optional[Any], ) -> Tuple[Any, MutableMapping[str, Any]]: """ Convert the state message to the format required by the ConcurrentCursor. @@ -50,23 +104,12 @@ def convert_from_sequential_state( ... @abstractmethod - def convert_to_sequential_state(self, cursor_field: "CursorField", stream_state: MutableMapping[str, Any]) -> MutableMapping[str, Any]: - """ - Convert the state message from the concurrency-compatible format to the stream's original format. - - e.g. - { "created": 1617030403 } - """ - ... - - @abstractmethod - def increment(self, timestamp: Any) -> Any: + def increment(self, value: Any) -> Any: """ Increment a timestamp by a single unit. """ ... - @abstractmethod def merge_intervals(self, intervals: List[MutableMapping[str, Any]]) -> List[MutableMapping[str, Any]]: """ Compute and return a list of merged intervals. @@ -74,7 +117,22 @@ def merge_intervals(self, intervals: List[MutableMapping[str, Any]]) -> List[Mut Intervals may be merged if the start time of the second interval is 1 unit or less (as defined by the `increment` method) than the end time of the first interval. """ - ... + if not intervals: + return [] + + sorted_intervals = sorted(intervals, key=lambda x: (x[self.START_KEY], x[self.END_KEY])) + merged_intervals = [sorted_intervals[0]] + + for interval in sorted_intervals[1:]: + last_end_time = merged_intervals[-1][self.END_KEY] + current_start_time = interval[self.START_KEY] + if bool(self.increment(last_end_time) >= current_start_time): + merged_end_time = max(last_end_time, interval[self.END_KEY]) + merged_intervals[-1][self.END_KEY] = merged_end_time + else: + merged_intervals.append(interval) + + return merged_intervals @abstractmethod def parse_value(self, value: Any) -> Any: diff --git a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py index 83f8a44b23db2..226ee79c04040 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/streams/concurrent/state_converters/datetime_stream_state_converter.py @@ -3,8 +3,8 @@ # from abc import abstractmethod -from datetime import datetime, timedelta -from typing import Any, List, MutableMapping, Optional, Tuple +from datetime import datetime, timedelta, timezone +from typing import Any, Callable, MutableMapping, Optional, Tuple import pendulum from airbyte_cdk.sources.streams.concurrent.cursor import CursorField @@ -16,6 +16,12 @@ class DateTimeStreamStateConverter(AbstractStreamStateConverter): + def _from_state_message(self, value: Any) -> Any: + return self.parse_timestamp(value) + + def _to_state_message(self, value: Any) -> Any: + return self.output_format(value) + @property @abstractmethod def _zero_value(self) -> Any: @@ -25,6 +31,10 @@ def _zero_value(self) -> Any: def zero_value(self) -> datetime: return self.parse_timestamp(self._zero_value) + @classmethod + def get_end_provider(cls) -> Callable[[], datetime]: + return lambda: datetime.now(timezone.utc) + @abstractmethod def increment(self, timestamp: datetime) -> datetime: ... @@ -37,41 +47,17 @@ def parse_timestamp(self, timestamp: Any) -> datetime: def output_format(self, timestamp: datetime) -> Any: ... - def deserialize(self, state: MutableMapping[str, Any]) -> MutableMapping[str, Any]: - for stream_slice in state.get("slices", []): - stream_slice[self.START_KEY] = self.parse_timestamp(stream_slice[self.START_KEY]) - stream_slice[self.END_KEY] = self.parse_timestamp(stream_slice[self.END_KEY]) - return state - def parse_value(self, value: Any) -> Any: """ Parse the value of the cursor field into a comparable value. """ return self.parse_timestamp(value) - def merge_intervals(self, intervals: List[MutableMapping[str, datetime]]) -> List[MutableMapping[str, datetime]]: - if not intervals: - return [] - - sorted_intervals = sorted(intervals, key=lambda x: (x[self.START_KEY], x[self.END_KEY])) - merged_intervals = [sorted_intervals[0]] - - for interval in sorted_intervals[1:]: - last_end_time = merged_intervals[-1][self.END_KEY] - current_start_time = interval[self.START_KEY] - if self._compare_intervals(last_end_time, current_start_time): - merged_end_time = max(last_end_time, interval[self.END_KEY]) - merged_intervals[-1][self.END_KEY] = merged_end_time - else: - merged_intervals.append(interval) - - return merged_intervals - def _compare_intervals(self, end_time: Any, start_time: Any) -> bool: return bool(self.increment(end_time) >= start_time) def convert_from_sequential_state( - self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], start: datetime + self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], start: Optional[datetime] ) -> Tuple[datetime, MutableMapping[str, Any]]: """ Convert the state message to the format required by the ConcurrentCursor. @@ -92,7 +78,7 @@ def convert_from_sequential_state( # Create a slice to represent the records synced during prior syncs. # The start and end are the same to avoid confusion as to whether the records for this slice # were actually synced - slices = [{self.START_KEY: sync_start, self.END_KEY: sync_start}] + slices = [{self.START_KEY: start if start is not None else sync_start, self.END_KEY: sync_start}] return sync_start, { "state_type": ConcurrencyCompatibleStateType.date_range.value, @@ -100,8 +86,8 @@ def convert_from_sequential_state( "legacy": stream_state, } - def _get_sync_start(self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], start: Optional[Any]) -> datetime: - sync_start = self.parse_timestamp(start) if start is not None else self.zero_value + def _get_sync_start(self, cursor_field: CursorField, stream_state: MutableMapping[str, Any], start: Optional[datetime]) -> datetime: + sync_start = start if start is not None else self.zero_value prev_sync_low_water_mark = ( self.parse_timestamp(stream_state[cursor_field.cursor_field_key]) if cursor_field.cursor_field_key in stream_state else None ) @@ -110,33 +96,6 @@ def _get_sync_start(self, cursor_field: CursorField, stream_state: MutableMappin else: return sync_start - def convert_to_sequential_state(self, cursor_field: CursorField, stream_state: MutableMapping[str, Any]) -> MutableMapping[str, Any]: - """ - Convert the state message from the concurrency-compatible format to the stream's original format. - - e.g. - { "created": "2021-01-18T21:18:20.000Z" } - """ - if self.is_state_message_compatible(stream_state): - legacy_state = stream_state.get("legacy", {}) - latest_complete_time = self._get_latest_complete_time(stream_state.get("slices", [])) - if latest_complete_time is not None: - legacy_state.update({cursor_field.cursor_field_key: self.output_format(latest_complete_time)}) - return legacy_state or {} - else: - return stream_state - - def _get_latest_complete_time(self, slices: List[MutableMapping[str, Any]]) -> Optional[datetime]: - """ - Get the latest time before which all records have been processed. - """ - if not slices: - raise RuntimeError("Expected at least one slice but there were none. This is unexpected; please contact Support.") - - merged_intervals = self.merge_intervals(slices) - first_interval = merged_intervals[0] - return first_interval[self.END_KEY] - class EpochValueConcurrentStreamStateConverter(DateTimeStreamStateConverter): """ diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py index 290bd26b9384f..b233e8039bc83 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_builder.py @@ -81,6 +81,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]: self._cursor_field, self._cursor_boundaries, None, + EpochValueConcurrentStreamStateConverter.get_end_provider() ) if self._cursor_field else FinalStateCursor(stream_name=stream.name, stream_namespace=stream.namespace, message_repository=self.message_repository), diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py index 2090a4dd1c14a..de2ca049edf1c 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/stream_facade_scenarios.py @@ -1,6 +1,7 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.streams.concurrent.cursor import CursorField from unit_tests.sources.file_based.scenarios.scenario_builder import IncrementalScenarioConfig, TestScenarioBuilder from unit_tests.sources.streams.concurrent.scenarios.stream_facade_builder import StreamFacadeSourceBuilder @@ -157,7 +158,7 @@ ] } ) - .set_expected_read_error(ValueError, "test exception") + .set_expected_read_error(StreamThreadException, "Exception while syncing stream stream1: test exception") .build() ) diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py index 3164f9ab565dc..4a0094c3bc463 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/scenarios/thread_based_concurrent_stream_scenarios.py @@ -3,6 +3,7 @@ # import logging +from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.message import InMemoryMessageRepository from airbyte_cdk.sources.streams.concurrent.cursor import FinalStateCursor from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream @@ -317,7 +318,7 @@ {"data": {"id": "1"}, "stream": "stream1"}, ] ) - .set_expected_read_error(ValueError, "test exception") + .set_expected_read_error(StreamThreadException, "Exception while syncing stream stream1: test exception") .set_expected_catalog( { "streams": [ diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py index 3e0e00b62d32f..666ff6df3ba2d 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_concurrent_read_processor.py @@ -20,6 +20,7 @@ from airbyte_cdk.models import Type as MessageType from airbyte_cdk.sources.concurrent_source.concurrent_read_processor import ConcurrentReadProcessor from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.message import LogMessage, MessageRepository from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream @@ -479,66 +480,7 @@ def test_on_record_emits_status_message_on_first_record_with_repository_message( assert expected_messages == messages @freezegun.freeze_time("2020-01-01T00:00:00") - def test_on_exception_stops_streams_and_raises_an_exception(self): - stream_instances_to_read_from = [self._stream, self._another_stream] - - handler = ConcurrentReadProcessor( - stream_instances_to_read_from, - self._partition_enqueuer, - self._thread_pool_manager, - self._logger, - self._slice_logger, - self._message_repository, - self._partition_reader, - ) - - handler.start_next_partition_generator() - - another_stream = Mock(spec=AbstractStream) - another_stream.name = _STREAM_NAME - another_stream.as_airbyte_stream.return_value = AirbyteStream( - name=_ANOTHER_STREAM_NAME, - json_schema={}, - supported_sync_modes=[SyncMode.full_refresh], - ) - - exception = RuntimeError("Something went wrong") - - messages = [] - - with self.assertRaises(RuntimeError): - for m in handler.on_exception(exception): - messages.append(m) - - expected_message = [ - AirbyteMessage( - type=MessageType.TRACE, - trace=AirbyteTraceMessage( - type=TraceType.STREAM_STATUS, - emitted_at=1577836800000.0, - stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name=_STREAM_NAME), status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE) - ), - ), - ), - AirbyteMessage( - type=MessageType.TRACE, - trace=AirbyteTraceMessage( - type=TraceType.STREAM_STATUS, - emitted_at=1577836800000.0, - stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name=_ANOTHER_STREAM_NAME), - status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE), - ), - ), - ), - ] - - assert messages == expected_message - self._thread_pool_manager.shutdown.assert_called_once() - - @freezegun.freeze_time("2020-01-01T00:00:00") - def test_on_exception_does_not_stop_streams_that_are_already_done(self): + def test_on_exception_return_trace_message_and_on_stream_complete_return_stream_status(self): stream_instances_to_read_from = [self._stream, self._another_stream] handler = ConcurrentReadProcessor( @@ -564,15 +506,13 @@ def test_on_exception_does_not_stop_streams_that_are_already_done(self): supported_sync_modes=[SyncMode.full_refresh], ) - exception = RuntimeError("Something went wrong") - - messages = [] + exception = StreamThreadException(RuntimeError("Something went wrong"), _STREAM_NAME) - with self.assertRaises(RuntimeError): - for m in handler.on_exception(exception): - messages.append(m) + exception_messages = list(handler.on_exception(exception)) + assert len(exception_messages) == 1 + assert exception_messages[0].type == MessageType.TRACE - expected_message = [ + assert list(handler.on_partition_complete_sentinel(PartitionCompleteSentinel(self._an_open_partition))) == [ AirbyteMessage( type=MessageType.TRACE, trace=AirbyteTraceMessage( @@ -585,9 +525,6 @@ def test_on_exception_does_not_stop_streams_that_are_already_done(self): ) ] - assert messages == expected_message - self._thread_pool_manager.shutdown.assert_called_once() - def test_is_done_is_false_if_there_are_any_instances_to_read_from(self): stream_instances_to_read_from = [self._stream] diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_cursor.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_cursor.py index 94ed5211eabb9..b8fa8b2f79e0c 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_cursor.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_cursor.py @@ -1,21 +1,25 @@ # # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # +from datetime import datetime, timedelta, timezone from typing import Any, Mapping, Optional from unittest import TestCase from unittest.mock import Mock +import freezegun import pytest from airbyte_cdk.sources.connector_state_manager import ConnectorStateManager from airbyte_cdk.sources.message import MessageRepository -from airbyte_cdk.sources.streams.concurrent.cursor import Comparable, ConcurrentCursor, CursorField +from airbyte_cdk.sources.streams.concurrent.cursor import ConcurrentCursor, CursorField, CursorValueType from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record +from airbyte_cdk.sources.streams.concurrent.state_converters.abstract_stream_state_converter import ConcurrencyCompatibleStateType from airbyte_cdk.sources.streams.concurrent.state_converters.datetime_stream_state_converter import EpochValueConcurrentStreamStateConverter _A_STREAM_NAME = "a stream name" _A_STREAM_NAMESPACE = "a stream namespace" _A_CURSOR_FIELD_KEY = "a_cursor_field_key" +_NO_STATE = {} _NO_PARTITION_IDENTIFIER = None _NO_SLICE = None _NO_SLICE_BOUNDARIES = None @@ -23,6 +27,7 @@ _UPPER_SLICE_BOUNDARY_FIELD = "upper_boundary" _SLICE_BOUNDARY_FIELDS = (_LOWER_SLICE_BOUNDARY_FIELD, _UPPER_SLICE_BOUNDARY_FIELD) _A_VERY_HIGH_CURSOR_VALUE = 1000000000 +_NO_LOOKBACK_WINDOW = timedelta(seconds=0) def _partition(_slice: Optional[Mapping[str, Any]]) -> Partition: @@ -31,27 +36,28 @@ def _partition(_slice: Optional[Mapping[str, Any]]) -> Partition: return partition -def _record(cursor_value: Comparable) -> Record: +def _record(cursor_value: CursorValueType) -> Record: return Record(data={_A_CURSOR_FIELD_KEY: cursor_value}, stream_name=_A_STREAM_NAME) -class ConcurrentCursorTest(TestCase): +class ConcurrentCursorStateTest(TestCase): def setUp(self) -> None: self._message_repository = Mock(spec=MessageRepository) self._state_manager = Mock(spec=ConnectorStateManager) - self._state_converter = EpochValueConcurrentStreamStateConverter() - def _cursor_with_slice_boundary_fields(self) -> ConcurrentCursor: + def _cursor_with_slice_boundary_fields(self, is_sequential_state=True) -> ConcurrentCursor: return ConcurrentCursor( _A_STREAM_NAME, _A_STREAM_NAMESPACE, {}, self._message_repository, self._state_manager, - self._state_converter, + EpochValueConcurrentStreamStateConverter(is_sequential_state), CursorField(_A_CURSOR_FIELD_KEY), _SLICE_BOUNDARY_FIELDS, None, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, ) def _cursor_without_slice_boundary_fields(self) -> ConcurrentCursor: @@ -61,10 +67,12 @@ def _cursor_without_slice_boundary_fields(self) -> ConcurrentCursor: {}, self._message_repository, self._state_manager, - self._state_converter, + EpochValueConcurrentStreamStateConverter(is_sequential_state=True), CursorField(_A_CURSOR_FIELD_KEY), None, None, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, ) def test_given_boundary_fields_when_close_partition_then_emit_state(self) -> None: @@ -82,6 +90,24 @@ def test_given_boundary_fields_when_close_partition_then_emit_state(self) -> Non {_A_CURSOR_FIELD_KEY: 0}, # State message is updated to the legacy format before being emitted ) + def test_given_state_not_sequential_when_close_partition_then_emit_state(self) -> None: + cursor = self._cursor_with_slice_boundary_fields(is_sequential_state=False) + cursor.close_partition( + _partition( + {_LOWER_SLICE_BOUNDARY_FIELD: 12, _UPPER_SLICE_BOUNDARY_FIELD: 30}, + ) + ) + + self._message_repository.emit_message.assert_called_once_with(self._state_manager.create_state_message.return_value) + self._state_manager.update_state_for_stream.assert_called_once_with( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + { + "slices": [{"end": 0, "start": 0}, {"end": 30, "start": 12}], + "state_type": "date-range" + }, + ) + def test_given_boundary_fields_when_close_partition_then_emit_updated_state(self) -> None: self._cursor_with_slice_boundary_fields().close_partition( _partition( @@ -137,3 +163,265 @@ def test_given_slice_boundaries_not_matching_slice_when_close_partition_then_rai cursor = self._cursor_with_slice_boundary_fields() with pytest.raises(KeyError): cursor.close_partition(_partition({"not_matching_key": "value"})) + + @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) + def test_given_no_state_when_generate_slices_then_create_slice_from_start_to_end(self): + start = datetime.fromtimestamp(10, timezone.utc) + cursor = ConcurrentCursor( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + _NO_STATE, + self._message_repository, + self._state_manager, + EpochValueConcurrentStreamStateConverter(is_sequential_state=False), + CursorField(_A_CURSOR_FIELD_KEY), + _SLICE_BOUNDARY_FIELDS, + start, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, + ) + + slices = list(cursor.generate_slices()) + + assert slices == [ + (datetime.fromtimestamp(10, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + ] + + @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) + def test_given_one_slice_when_generate_slices_then_create_slice_from_slice_upper_boundary_to_end(self): + start = datetime.fromtimestamp(0, timezone.utc) + cursor = ConcurrentCursor( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + { + "state_type": ConcurrencyCompatibleStateType.date_range.value, + "slices": [ + {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + ] + }, + self._message_repository, + self._state_manager, + EpochValueConcurrentStreamStateConverter(is_sequential_state=False), + CursorField(_A_CURSOR_FIELD_KEY), + _SLICE_BOUNDARY_FIELDS, + start, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, + ) + + slices = list(cursor.generate_slices()) + + assert slices == [ + (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + ] + + @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) + def test_given_start_after_slices_when_generate_slices_then_generate_from_start(self): + start = datetime.fromtimestamp(30, timezone.utc) + cursor = ConcurrentCursor( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + { + "state_type": ConcurrencyCompatibleStateType.date_range.value, + "slices": [ + {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + ] + }, + self._message_repository, + self._state_manager, + EpochValueConcurrentStreamStateConverter(is_sequential_state=False), + CursorField(_A_CURSOR_FIELD_KEY), + _SLICE_BOUNDARY_FIELDS, + start, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, + ) + + slices = list(cursor.generate_slices()) + + assert slices == [ + (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + ] + + @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) + def test_given_state_with_gap_and_start_after_slices_when_generate_slices_then_generate_from_start(self): + start = datetime.fromtimestamp(30, timezone.utc) + cursor = ConcurrentCursor( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + { + "state_type": ConcurrencyCompatibleStateType.date_range.value, + "slices": [ + {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 10}, + {EpochValueConcurrentStreamStateConverter.START_KEY: 15, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + ] + }, + self._message_repository, + self._state_manager, + EpochValueConcurrentStreamStateConverter(is_sequential_state=False), + CursorField(_A_CURSOR_FIELD_KEY), + _SLICE_BOUNDARY_FIELDS, + start, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, + ) + + slices = list(cursor.generate_slices()) + + assert slices == [ + (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + ] + + @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) + def test_given_small_slice_range_when_generate_slices_then_create_many_slices(self): + start = datetime.fromtimestamp(0, timezone.utc) + small_slice_range = timedelta(seconds=10) + cursor = ConcurrentCursor( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + { + "state_type": ConcurrencyCompatibleStateType.date_range.value, + "slices": [ + {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + ] + }, + self._message_repository, + self._state_manager, + EpochValueConcurrentStreamStateConverter(is_sequential_state=False), + CursorField(_A_CURSOR_FIELD_KEY), + _SLICE_BOUNDARY_FIELDS, + start, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, + small_slice_range, + ) + + slices = list(cursor.generate_slices()) + + assert slices == [ + (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(30, timezone.utc)), + (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(40, timezone.utc)), + (datetime.fromtimestamp(40, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + ] + + @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) + def test_given_difference_between_slices_match_slice_range_when_generate_slices_then_create_one_slice(self): + start = datetime.fromtimestamp(0, timezone.utc) + small_slice_range = timedelta(seconds=10) + cursor = ConcurrentCursor( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + { + "state_type": ConcurrencyCompatibleStateType.date_range.value, + "slices": [ + {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 30}, + {EpochValueConcurrentStreamStateConverter.START_KEY: 40, EpochValueConcurrentStreamStateConverter.END_KEY: 50}, + ] + }, + self._message_repository, + self._state_manager, + EpochValueConcurrentStreamStateConverter(is_sequential_state=False), + CursorField(_A_CURSOR_FIELD_KEY), + _SLICE_BOUNDARY_FIELDS, + start, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, + small_slice_range, + ) + + slices = list(cursor.generate_slices()) + + assert slices == [ + (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(40, timezone.utc)), + ] + + @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) + def test_given_non_continuous_state_when_generate_slices_then_create_slices_between_gaps_and_after(self): + cursor = ConcurrentCursor( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + { + "state_type": ConcurrencyCompatibleStateType.date_range.value, + "slices": [ + {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 10}, + {EpochValueConcurrentStreamStateConverter.START_KEY: 20, EpochValueConcurrentStreamStateConverter.END_KEY: 25}, + {EpochValueConcurrentStreamStateConverter.START_KEY: 30, EpochValueConcurrentStreamStateConverter.END_KEY: 40}, + ] + }, + self._message_repository, + self._state_manager, + EpochValueConcurrentStreamStateConverter(is_sequential_state=False), + CursorField(_A_CURSOR_FIELD_KEY), + _SLICE_BOUNDARY_FIELDS, + None, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, + ) + + slices = list(cursor.generate_slices()) + + assert slices == [ + (datetime.fromtimestamp(10, timezone.utc), datetime.fromtimestamp(20, timezone.utc)), + (datetime.fromtimestamp(25, timezone.utc), datetime.fromtimestamp(30, timezone.utc)), + (datetime.fromtimestamp(40, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + ] + + @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) + def test_given_lookback_window_when_generate_slices_then_apply_lookback_on_most_recent_slice(self): + start = datetime.fromtimestamp(0, timezone.utc) + lookback_window = timedelta(seconds=10) + cursor = ConcurrentCursor( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + { + "state_type": ConcurrencyCompatibleStateType.date_range.value, + "slices": [ + {EpochValueConcurrentStreamStateConverter.START_KEY: 0, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + {EpochValueConcurrentStreamStateConverter.START_KEY: 30, EpochValueConcurrentStreamStateConverter.END_KEY: 40}, + ] + }, + self._message_repository, + self._state_manager, + EpochValueConcurrentStreamStateConverter(is_sequential_state=False), + CursorField(_A_CURSOR_FIELD_KEY), + _SLICE_BOUNDARY_FIELDS, + start, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + lookback_window, + ) + + slices = list(cursor.generate_slices()) + + assert slices == [ + (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(30, timezone.utc)), + (datetime.fromtimestamp(30, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + ] + + @freezegun.freeze_time(time_to_freeze=datetime.fromtimestamp(50, timezone.utc)) + def test_given_start_is_before_first_slice_lower_boundary_when_generate_slices_then_generate_slice_before(self): + start = datetime.fromtimestamp(0, timezone.utc) + cursor = ConcurrentCursor( + _A_STREAM_NAME, + _A_STREAM_NAMESPACE, + { + "state_type": ConcurrencyCompatibleStateType.date_range.value, + "slices": [ + {EpochValueConcurrentStreamStateConverter.START_KEY: 10, EpochValueConcurrentStreamStateConverter.END_KEY: 20}, + ] + }, + self._message_repository, + self._state_manager, + EpochValueConcurrentStreamStateConverter(is_sequential_state=False), + CursorField(_A_CURSOR_FIELD_KEY), + _SLICE_BOUNDARY_FIELDS, + start, + EpochValueConcurrentStreamStateConverter.get_end_provider(), + _NO_LOOKBACK_WINDOW, + ) + + slices = list(cursor.generate_slices()) + + assert slices == [ + (datetime.fromtimestamp(0, timezone.utc), datetime.fromtimestamp(10, timezone.utc)), + (datetime.fromtimestamp(20, timezone.utc), datetime.fromtimestamp(50, timezone.utc)), + ] diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py index 534dbd580787c..aeaf5ae5a50ba 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_datetime_state_converter.py @@ -98,7 +98,7 @@ def test_concurrent_stream_state_converter_is_state_message_compatible(converter ), pytest.param( EpochValueConcurrentStreamStateConverter(), - 1617030403, + datetime.fromtimestamp(1617030403, timezone.utc), {}, datetime(2021, 3, 29, 15, 6, 43, tzinfo=timezone.utc), id="epoch-converter-no-state-with-start-start-is-start", @@ -112,14 +112,14 @@ def test_concurrent_stream_state_converter_is_state_message_compatible(converter ), pytest.param( EpochValueConcurrentStreamStateConverter(), - 1617030404, + datetime.fromtimestamp(1617030404, timezone.utc), {"created_at": 1617030403}, datetime(2021, 3, 29, 15, 6, 44, tzinfo=timezone.utc), id="epoch-converter-state-before-start-start-is-start", ), pytest.param( EpochValueConcurrentStreamStateConverter(), - 1617030403, + datetime.fromtimestamp(1617030403, timezone.utc), {"created_at": 1617030404}, datetime(2021, 3, 29, 15, 6, 44, tzinfo=timezone.utc), id="epoch-converter-state-after-start-start-is-from-state", @@ -133,7 +133,7 @@ def test_concurrent_stream_state_converter_is_state_message_compatible(converter ), pytest.param( IsoMillisConcurrentStreamStateConverter(), - "2021-08-22T05:03:27.000Z", + datetime(2021, 8, 22, 5, 3, 27, tzinfo=timezone.utc), {}, datetime(2021, 8, 22, 5, 3, 27, tzinfo=timezone.utc), id="isomillis-converter-no-state-with-start-start-is-start", @@ -147,14 +147,14 @@ def test_concurrent_stream_state_converter_is_state_message_compatible(converter ), pytest.param( IsoMillisConcurrentStreamStateConverter(), - "2022-08-22T05:03:27.000Z", + datetime(2022, 8, 22, 5, 3, 27, tzinfo=timezone.utc), {"created_at": "2021-08-22T05:03:27.000Z"}, datetime(2022, 8, 22, 5, 3, 27, tzinfo=timezone.utc), id="isomillis-converter-state-before-start-start-is-start", ), pytest.param( IsoMillisConcurrentStreamStateConverter(), - "2022-08-22T05:03:27.000Z", + datetime(2022, 8, 22, 5, 3, 27, tzinfo=timezone.utc), {"created_at": "2023-08-22T05:03:27.000Z"}, datetime(2023, 8, 22, 5, 3, 27, tzinfo=timezone.utc), id="isomillis-converter-state-after-start-start-is-from-state", @@ -170,7 +170,7 @@ def test_get_sync_start(converter, start, state, expected_start): [ pytest.param( EpochValueConcurrentStreamStateConverter(), - 0, + datetime.fromtimestamp(0, timezone.utc), {}, { "legacy": {}, @@ -186,13 +186,13 @@ def test_get_sync_start(converter, start, state, expected_start): ), pytest.param( EpochValueConcurrentStreamStateConverter(), - 1617030403, + datetime.fromtimestamp(1577836800, timezone.utc), {"created": 1617030403}, { "state_type": "date-range", "slices": [ { - "start": datetime(2021, 3, 29, 15, 6, 43, tzinfo=timezone.utc), + "start": datetime(2020, 1, 1, tzinfo=timezone.utc), "end": datetime(2021, 3, 29, 15, 6, 43, tzinfo=timezone.utc), } ], @@ -202,13 +202,13 @@ def test_get_sync_start(converter, start, state, expected_start): ), pytest.param( IsoMillisConcurrentStreamStateConverter(), - "2020-01-01T00:00:00.000Z", + datetime(2020, 1, 1, tzinfo=timezone.utc), {"created": "2021-08-22T05:03:27.000Z"}, { "state_type": "date-range", "slices": [ { - "start": datetime(2021, 8, 22, 5, 3, 27, tzinfo=timezone.utc), + "start": datetime(2020, 1, 1, tzinfo=timezone.utc), "end": datetime(2021, 8, 22, 5, 3, 27, tzinfo=timezone.utc), } ], @@ -338,7 +338,7 @@ def test_convert_from_sequential_state(converter, start, sequential_state, expec ], ) def test_convert_to_sequential_state(converter, concurrent_state, expected_output_state): - assert converter.convert_to_sequential_state(CursorField("created"), concurrent_state) == expected_output_state + assert converter.convert_to_state_message(CursorField("created"), concurrent_state) == expected_output_state @pytest.mark.parametrize( @@ -366,4 +366,4 @@ def test_convert_to_sequential_state(converter, concurrent_state, expected_outpu ) def test_convert_to_sequential_state_no_slices_returns_legacy_state(converter, concurrent_state, expected_output_state): with pytest.raises(RuntimeError): - converter.convert_to_sequential_state(CursorField("created"), concurrent_state) + converter.convert_to_state_message(CursorField("created"), concurrent_state) diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py index bdcd9ad43318c..d11154e712978 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_enqueuer.py @@ -7,6 +7,7 @@ from unittest.mock import Mock, patch from airbyte_cdk.sources.concurrent_source.partition_generation_completed_sentinel import PartitionGenerationCompletedSentinel +from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream from airbyte_cdk.sources.streams.concurrent.partition_enqueuer import PartitionEnqueuer @@ -14,6 +15,7 @@ from airbyte_cdk.sources.streams.concurrent.partitions.types import QueueItem _SOME_PARTITIONS: List[Partition] = [Mock(spec=Partition), Mock(spec=Partition)] +_A_STREAM_NAME = "a_stream_name" class PartitionEnqueuerTest(unittest.TestCase): @@ -57,14 +59,16 @@ def test_given_partition_but_limit_reached_when_generate_partitions_then_wait_un assert mocked_sleep.call_count == 2 - def test_given_exception_when_generate_partitions_then_raise(self): + def test_given_exception_when_generate_partitions_then_return_exception_and_sentinel(self): stream = Mock(spec=AbstractStream) + stream.name = _A_STREAM_NAME exception = ValueError() stream.generate_partitions.side_effect = self._partitions_before_raising(_SOME_PARTITIONS, exception) self._partition_generator.generate_partitions(stream) - assert self._consume_queue() == _SOME_PARTITIONS + [exception] + queue_content = self._consume_queue() + assert queue_content == _SOME_PARTITIONS + [StreamThreadException(exception, _A_STREAM_NAME), PartitionGenerationCompletedSentinel(stream)] def _partitions_before_raising(self, partitions: List[Partition], exception: Exception) -> Callable[[], Iterable[Partition]]: def inner_function() -> Iterable[Partition]: @@ -83,7 +87,7 @@ def _a_stream(partitions: List[Partition]) -> AbstractStream: def _consume_queue(self) -> List[QueueItem]: queue_content: List[QueueItem] = [] while queue_item := self._queue.get(): - if isinstance(queue_item, (PartitionGenerationCompletedSentinel, Exception)): + if isinstance(queue_item, PartitionGenerationCompletedSentinel): queue_content.append(queue_item) break queue_content.append(queue_item) diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_reader.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_reader.py index 9e9fb89739496..226652be82a1c 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_reader.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_partition_reader.py @@ -7,6 +7,7 @@ from unittest.mock import Mock import pytest +from airbyte_cdk.sources.concurrent_source.stream_thread_exception import StreamThreadException from airbyte_cdk.sources.streams.concurrent.partition_reader import PartitionReader from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition from airbyte_cdk.sources.streams.concurrent.partitions.record import Record @@ -32,26 +33,22 @@ def test_given_no_records_when_process_partition_then_only_emit_sentinel(self): break def test_given_read_partition_successful_when_process_partition_then_queue_records_and_sentinel(self): - self._partition_reader.process_partition(self._a_partition(_RECORDS)) + partition = self._a_partition(_RECORDS) + self._partition_reader.process_partition(partition) - actual_records = [] - while queue_item := self._queue.get(): - if isinstance(queue_item, PartitionCompleteSentinel): - break - actual_records.append(queue_item) + queue_content = self._consume_queue() - assert _RECORDS == actual_records + assert queue_content == _RECORDS + [PartitionCompleteSentinel(partition)] - def test_given_exception_when_process_partition_then_queue_records_and_raise_exception(self): + def test_given_exception_when_process_partition_then_queue_records_and_exception_and_sentinel(self): partition = Mock() exception = ValueError() partition.read.side_effect = self._read_with_exception(_RECORDS, exception) - self._partition_reader.process_partition(partition) - for i in range(len(_RECORDS)): - assert self._queue.get() == _RECORDS[i] - assert self._queue.get() == exception + queue_content = self._consume_queue() + + assert queue_content == _RECORDS + [StreamThreadException(exception, partition.stream_name()), PartitionCompleteSentinel(partition)] def _a_partition(self, records: List[Record]) -> Partition: partition = Mock(spec=Partition) @@ -65,3 +62,11 @@ def mocked_function() -> Iterable[Record]: raise exception return mocked_function + + def _consume_queue(self): + queue_content = [] + while queue_item := self._queue.get(): + queue_content.append(queue_item) + if isinstance(queue_item, PartitionCompleteSentinel): + break + return queue_content diff --git a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py index 102cf7cdd4482..197f9b3431e85 100644 --- a/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py +++ b/airbyte-cdk/python/unit_tests/sources/streams/concurrent/test_thread_pool_manager.py @@ -21,27 +21,6 @@ def test_submit_calls_underlying_thread_pool(self): assert len(self._thread_pool_manager._futures) == 1 - def test_given_no_exceptions_when_shutdown_if_exception_then_do_not_raise(self): - future = Mock(spec=Future) - future.exception.return_value = None - future.done.side_effect = [True, True] - - self._thread_pool_manager._futures = [future] - self._thread_pool_manager.prune_to_validate_has_reached_futures_limit() - - self._thread_pool_manager.shutdown_if_exception() # do not raise - - def test_given_exception_when_shutdown_if_exception_then_raise(self): - future = Mock(spec=Future) - future.exception.return_value = RuntimeError - future.done.side_effect = [True, True] - - self._thread_pool_manager._futures = [future] - self._thread_pool_manager.prune_to_validate_has_reached_futures_limit() - - with self.assertRaises(RuntimeError): - self._thread_pool_manager.shutdown_if_exception() - def test_given_exception_during_pruning_when_check_for_errors_and_shutdown_then_shutdown_and_raise(self): future = Mock(spec=Future) future.exception.return_value = RuntimeError @@ -54,10 +33,6 @@ def test_given_exception_during_pruning_when_check_for_errors_and_shutdown_then_ self._thread_pool_manager.check_for_errors_and_shutdown() self._threadpool.shutdown.assert_called_with(wait=False, cancel_futures=True) - def test_shutdown(self): - self._thread_pool_manager.shutdown() - self._threadpool.shutdown.assert_called_with(wait=False, cancel_futures=True) - def test_is_done_is_false_if_not_all_futures_are_done(self): future = Mock(spec=Future) future.done.return_value = False diff --git a/airbyte-cdk/python/unit_tests/sources/test_concurrent_source.py b/airbyte-cdk/python/unit_tests/sources/test_concurrent_source.py deleted file mode 100644 index 9ec0a293cdf6c..0000000000000 --- a/airbyte-cdk/python/unit_tests/sources/test_concurrent_source.py +++ /dev/null @@ -1,112 +0,0 @@ -# -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. -# -import concurrent -import logging -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple -from unittest.mock import Mock - -from airbyte_cdk.models import SyncMode -from airbyte_cdk.sources.concurrent_source.concurrent_source import ConcurrentSource -from airbyte_cdk.sources.concurrent_source.thread_pool_manager import ThreadPoolManager -from airbyte_cdk.sources.message import InMemoryMessageRepository, MessageRepository -from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream -from airbyte_cdk.sources.streams.concurrent.availability_strategy import StreamAvailability, StreamAvailable, StreamUnavailable -from airbyte_cdk.sources.streams.concurrent.cursor import Cursor, FinalStateCursor -from airbyte_cdk.sources.streams.concurrent.partitions.partition import Partition -from airbyte_cdk.sources.streams.concurrent.partitions.record import Record -from airbyte_protocol.models import AirbyteStream - -logger = logging.getLogger("airbyte") - - -class _MockSource(ConcurrentSource): - def __init__( - self, - check_lambda: Callable[[], Tuple[bool, Optional[Any]]] = None, - per_stream: bool = True, - message_repository: MessageRepository = InMemoryMessageRepository(), - threadpool: ThreadPoolManager = ThreadPoolManager( - concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="workerpool"), logger - ), - exception_on_missing_stream: bool = True, - ): - super().__init__(threadpool, Mock(), Mock(), message_repository) - self.check_lambda = check_lambda - self.per_stream = per_stream - self.exception_on_missing_stream = exception_on_missing_stream - self._message_repository = message_repository - - -MESSAGE_FROM_REPOSITORY = Mock() - - -class _MockStream(AbstractStream): - def __init__(self, name: str, message_repository: MessageRepository, available: bool = True, json_schema: Dict[str, Any] = {}): - self._name = name - self._available = available - self._json_schema = json_schema - self._message_repository = message_repository - - def generate_partitions(self) -> Iterable[Partition]: - yield _MockPartition(self._name) - - @property - def name(self) -> str: - return self._name - - @property - def cursor_field(self) -> Optional[str]: - raise NotImplementedError - - def check_availability(self) -> StreamAvailability: - if self._available: - return StreamAvailable() - else: - return StreamUnavailable("stream is unavailable") - - def get_json_schema(self) -> Mapping[str, Any]: - return self._json_schema - - def as_airbyte_stream(self) -> AirbyteStream: - return AirbyteStream(name=self.name, json_schema=self.get_json_schema(), supported_sync_modes=[SyncMode.full_refresh]) - - def log_stream_sync_configuration(self) -> None: - raise NotImplementedError - - @property - def cursor(self) -> Cursor: - return FinalStateCursor(stream_name=self._name, stream_namespace=None, message_repository=self._message_repository) - - -class _MockPartition(Partition): - def __init__(self, name: str): - self._name = name - self._closed = False - - def read(self) -> Iterable[Record]: - yield from [Record({"key": "value"}, self._name)] - - def to_slice(self) -> Optional[Mapping[str, Any]]: - return {} - - def stream_name(self) -> str: - return self._name - - def close(self) -> None: - self._closed = True - - def is_closed(self) -> bool: - return self._closed - - def __hash__(self) -> int: - return hash(self._name) - - -def test_concurrent_source_reading_from_no_streams(): - message_repository = InMemoryMessageRepository() - stream = _MockStream("my_stream", message_repository,False, {}) - source = _MockSource(message_repository=message_repository) - messages = [] - for m in source.read([stream]): - messages.append(m) diff --git a/airbyte-cdk/python/unit_tests/sources/test_source_read.py b/airbyte-cdk/python/unit_tests/sources/test_source_read.py index 8aaeed6b777ec..61b4f0229534e 100644 --- a/airbyte-cdk/python/unit_tests/sources/test_source_read.py +++ b/airbyte-cdk/python/unit_tests/sources/test_source_read.py @@ -126,7 +126,8 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_e config = {} catalog = _create_configured_catalog(source._streams) - messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, None) + # FIXME this is currently unused in this test + # messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, None) messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, None) expected_messages = [ @@ -267,7 +268,7 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_no_e ), ), ] - _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source) + _verify_messages(expected_messages, messages_from_concurrent_source) @freezegun.freeze_time("2020-01-01T00:00:00") @@ -283,53 +284,9 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_a_tr messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException) messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, AirbyteTracedException) - expected_messages = [ - AirbyteMessage( - type=MessageType.TRACE, - trace=AirbyteTraceMessage( - type=TraceType.STREAM_STATUS, - emitted_at=1577836800000.0, - error=None, - estimate=None, - stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED) - ), - ), - ), - AirbyteMessage( - type=MessageType.TRACE, - trace=AirbyteTraceMessage( - type=TraceType.STREAM_STATUS, - emitted_at=1577836800000.0, - error=None, - estimate=None, - stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING) - ), - ), - ), - AirbyteMessage( - type=MessageType.RECORD, - record=AirbyteRecordMessage( - stream="stream0", - data=records[0], - emitted_at=1577836800000, - ), - ), - AirbyteMessage( - type=MessageType.TRACE, - trace=AirbyteTraceMessage( - type=TraceType.STREAM_STATUS, - emitted_at=1577836800000.0, - error=None, - estimate=None, - stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE) - ), - ), - ), - ] - _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source) + _assert_status_messages(messages_from_abstract_source, messages_from_concurrent_source) + _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source) + _assert_errors(messages_from_abstract_source, messages_from_concurrent_source) @freezegun.freeze_time("2020-01-01T00:00:00") @@ -346,53 +303,38 @@ def test_concurrent_source_yields_the_same_messages_as_abstract_source_when_an_e messages_from_abstract_source = _read_from_source(source, logger, config, catalog, state, AirbyteTracedException) messages_from_concurrent_source = _read_from_source(concurrent_source, logger, config, catalog, state, RuntimeError) - expected_messages = [ - AirbyteMessage( - type=MessageType.TRACE, - trace=AirbyteTraceMessage( - type=TraceType.STREAM_STATUS, - emitted_at=1577836800000.0, - error=None, - estimate=None, - stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.STARTED) - ), - ), - ), - AirbyteMessage( - type=MessageType.TRACE, - trace=AirbyteTraceMessage( - type=TraceType.STREAM_STATUS, - emitted_at=1577836800000.0, - error=None, - estimate=None, - stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.RUNNING) - ), - ), - ), - AirbyteMessage( - type=MessageType.RECORD, - record=AirbyteRecordMessage( - stream="stream0", - data=records[0], - emitted_at=1577836800000, - ), - ), - AirbyteMessage( - type=MessageType.TRACE, - trace=AirbyteTraceMessage( - type=TraceType.STREAM_STATUS, - emitted_at=1577836800000.0, - error=None, - estimate=None, - stream_status=AirbyteStreamStatusTraceMessage( - stream_descriptor=StreamDescriptor(name="stream0"), status=AirbyteStreamStatus(AirbyteStreamStatus.INCOMPLETE) - ), - ), - ), - ] - _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source) + _assert_status_messages(messages_from_abstract_source, messages_from_concurrent_source) + _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source) + _assert_errors(messages_from_abstract_source, messages_from_concurrent_source) + + +def _assert_status_messages(messages_from_abstract_source, messages_from_concurrent_source): + status_from_concurrent_source = [message for message in messages_from_concurrent_source if message.type == MessageType.TRACE and message.trace.type == TraceType.STREAM_STATUS] + + assert status_from_concurrent_source + _verify_messages( + [message for message in messages_from_abstract_source if message.type == MessageType.TRACE and message.trace.type == TraceType.STREAM_STATUS], + status_from_concurrent_source, + ) + + +def _assert_record_messages(messages_from_abstract_source, messages_from_concurrent_source): + records_from_concurrent_source = [message for message in messages_from_concurrent_source if message.type == MessageType.RECORD] + + assert records_from_concurrent_source + _verify_messages( + [message for message in messages_from_abstract_source if message.type == MessageType.RECORD], + records_from_concurrent_source, + ) + + +def _assert_errors(messages_from_abstract_source, messages_from_concurrent_source): + errors_from_concurrent_source = [message for message in messages_from_concurrent_source if message.type == MessageType.TRACE and message.trace.type == TraceType.ERROR] + errors_from_abstract_source = [message for message in messages_from_abstract_source if message.type == MessageType.TRACE and message.trace.type == TraceType.ERROR] + + assert errors_from_concurrent_source + # exceptions might differ from both framework hence we only assert the count + assert len(errors_from_concurrent_source) == len(errors_from_abstract_source) def _init_logger(): @@ -442,7 +384,7 @@ def _read_from_source(source, logger, config, catalog, state, expected_exception return messages -def _verify_messages(expected_messages, messages_from_abstract_source, messages_from_concurrent_source): +def _verify_messages(expected_messages, messages_from_concurrent_source): assert _compare(expected_messages, messages_from_concurrent_source)